use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor};
use ferrotorch_nn::module::{Module, StateDict};
use ferrotorch_nn::parameter::Parameter;
use ferrotorch_nn::{Conv2d, GELU, GroupNorm, LayerNorm, Linear};
/// Multi-head attention layer usable for self- or cross-attention.
///
/// Queries are projected from `query_dim` features; keys and values are
/// projected from `kv_dim` features, which equals `query_dim` unless a
/// cross-attention width was passed to `Attention::new`.
#[derive(Debug)]
pub struct Attention<T: Float> {
/// Width of each attention head.
pub dim_head: usize,
/// Number of attention heads.
pub heads: usize,
/// Total projected width, computed as `heads * dim_head`.
pub inner_dim: usize,
/// Query projection (`query_dim -> inner_dim`).
pub to_q: Linear<T>,
/// Key projection (`kv_dim -> inner_dim`).
pub to_k: Linear<T>,
/// Value projection (`kv_dim -> inner_dim`).
pub to_v: Linear<T>,
/// Output projection (`inner_dim -> query_dim`); its parameters are named
/// with the `to_out.0.` prefix.
pub to_out_0: Linear<T>,
/// Expected size of the last axis of the query input.
query_dim: usize,
/// Expected size of the last axis of the key/value input.
kv_dim: usize,
/// Factor applied to the attention scores before softmax: `1 / sqrt(dim_head)`.
scale: f64,
/// Mode flag toggled by `train()` / `eval()`; starts as `false`.
training: bool,
}
impl<T: Float> Attention<T> {
    /// Creates an attention layer with `heads` heads of width `dim_head`.
    ///
    /// * `query_dim` - feature width of the query input and of the output.
    /// * `cross_attention_dim` - feature width of the key/value input;
    ///   `None` reuses `query_dim`.
    /// * `bias` - whether the q/k/v projections are built with a bias. The
    ///   output projection is always built with one.
    ///
    /// The layer starts with its training flag set to `false`.
    ///
    /// # Errors
    ///
    /// Returns [`FerrotorchError::InvalidArgument`] when `heads` or
    /// `dim_head` is zero (a zero `dim_head` would make the
    /// `1 / sqrt(dim_head)` scale infinite) or when `heads * dim_head`
    /// overflows `usize`. Errors from `Linear::new` are propagated.
    pub fn new(
        query_dim: usize,
        cross_attention_dim: Option<usize>,
        heads: usize,
        dim_head: usize,
        bias: bool,
    ) -> FerrotorchResult<Self> {
        if heads == 0 || dim_head == 0 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "Attention::new: heads and dim_head must be non-zero, \
                     got heads={heads}, dim_head={dim_head}"
                ),
            });
        }
        let inner_dim =
            heads
                .checked_mul(dim_head)
                .ok_or_else(|| FerrotorchError::InvalidArgument {
                    message: format!(
                        "Attention::new: heads * dim_head overflows usize \
                         (heads={heads}, dim_head={dim_head})"
                    ),
                })?;
        let kv_dim = cross_attention_dim.unwrap_or(query_dim);
        // Construction order (q, k, v, out) is kept as-is in case layer
        // initialisation is order-sensitive.
        let to_q = Linear::<T>::new(query_dim, inner_dim, bias)?;
        let to_k = Linear::<T>::new(kv_dim, inner_dim, bias)?;
        let to_v = Linear::<T>::new(kv_dim, inner_dim, bias)?;
        let to_out_0 = Linear::<T>::new(inner_dim, query_dim, true)?;
        let scale = (dim_head as f64).sqrt().recip();
        Ok(Self {
            dim_head,
            heads,
            inner_dim,
            to_q,
            to_k,
            to_v,
            to_out_0,
            query_dim,
            kv_dim,
            scale,
            training: false,
        })
    }

    /// Rearranges `[batch, len, heads * dim_head]` into
    /// `[batch * heads, len, dim_head]`.
    fn split_heads(
        &self,
        t: &Tensor<T>,
        batch: usize,
        len: usize,
    ) -> FerrotorchResult<Tensor<T>> {
        let (h, d) = (self.heads, self.dim_head);
        let out = t
            .reshape_t(&[batch as isize, len as isize, h as isize, d as isize])?
            .transpose(1, 2)?
            // Materialise the transposed layout before folding heads into
            // the batch axis.
            .contiguous()?
            .reshape_t(&[(batch * h) as isize, len as isize, d as isize])?;
        Ok(out)
    }

    /// Inverse of [`Self::split_heads`]: rearranges
    /// `[batch * heads, len, dim_head]` into `[batch, len, inner_dim]`.
    fn merge_heads(
        &self,
        t: &Tensor<T>,
        batch: usize,
        len: usize,
    ) -> FerrotorchResult<Tensor<T>> {
        let (h, d) = (self.heads, self.dim_head);
        let out = t
            .reshape_t(&[batch as isize, h as isize, len as isize, d as isize])?
            .transpose(1, 2)?
            .contiguous()?
            .reshape_t(&[batch as isize, len as isize, self.inner_dim as isize])?;
        Ok(out)
    }

    /// Runs attention of `hidden_states` over `encoder_hidden_states`.
    ///
    /// * `hidden_states` - rank-3 tensor `[B, N, query_dim]`.
    /// * `encoder_hidden_states` - optional rank-3 tensor `[B, S, kv_dim]`;
    ///   `None` attends over `hidden_states` itself (self-attention).
    ///
    /// Returns the output of the `inner_dim -> query_dim` projection applied
    /// to the merged heads, i.e. a `[B, N, query_dim]` tensor.
    ///
    /// # Errors
    ///
    /// Returns [`FerrotorchError::ShapeMismatch`] when either input is not
    /// rank 3, has the wrong feature width, or the batch sizes differ.
    /// Errors from the underlying tensor operations are propagated.
    pub fn forward_xattn(
        &self,
        hidden_states: &Tensor<T>,
        encoder_hidden_states: Option<&Tensor<T>>,
    ) -> FerrotorchResult<Tensor<T>> {
        if hidden_states.ndim() != 3 || hidden_states.shape()[2] != self.query_dim {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "Attention::forward_xattn: expected hidden_states [B, N, {}], got {:?}",
                    self.query_dim,
                    hidden_states.shape()
                ),
            });
        }
        let b = hidden_states.shape()[0];
        let n = hidden_states.shape()[1];

        let kv = encoder_hidden_states.unwrap_or(hidden_states);
        if kv.ndim() != 3 || kv.shape()[0] != b || kv.shape()[2] != self.kv_dim {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "Attention::forward_xattn: expected kv [B={b}, S, {}], got {:?}",
                    self.kv_dim,
                    kv.shape()
                ),
            });
        }
        let s = kv.shape()[1];

        // Project, then fold the head axis into the batch axis for `bmm`.
        let q = self.split_heads(&self.to_q.forward(hidden_states)?, b, n)?;
        let k = self.split_heads(&self.to_k.forward(kv)?, b, s)?;
        let v = self.split_heads(&self.to_v.forward(kv)?, b, s)?;

        // scores = (q @ k^T) * scale, shape [B * heads, N, S].
        let k_t = k.transpose(1, 2)?.contiguous()?;
        let scores = q.bmm(&k_t)?;
        let scale_t = T::from(self.scale).ok_or_else(|| FerrotorchError::InvalidArgument {
            message: "Attention::forward_xattn: failed to cast attention scale into Float".into(),
        })?;
        let scale_tensor = ferrotorch_core::scalar::<T>(scale_t)?;
        let scores_scaled = ferrotorch_core::grad_fns::arithmetic::mul(&scores, &scale_tensor)?;

        // NOTE(review): `softmax()` takes no axis here; presumably it
        // normalises over the last (key) axis - confirm against the Tensor API.
        let probs = scores_scaled.softmax()?;
        let attended = probs.bmm(&v)?;

        let attended = self.merge_heads(&attended, b, n)?;
        self.to_out_0.forward(&attended)
    }
}
impl<T: Float> Module<T> for Attention<T> {
fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
self.forward_xattn(input, None)
}
fn parameters(&self) -> Vec<&Parameter<T>> {
let mut o = Vec::new();
o.extend(self.to_q.parameters());
o.extend(self.to_k.parameters());
o.extend(self.to_v.parameters());
o.extend(self.to_out_0.parameters());
o
}
fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
let mut o = Vec::new();
o.extend(self.to_q.parameters_mut());
o.extend(self.to_k.parameters_mut());
o.extend(self.to_v.parameters_mut());
o.extend(self.to_out_0.parameters_mut());
o
}
fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
let mut o = Vec::new();
for (n, p) in self.to_q.named_parameters() {
o.push((format!("to_q.{n}"), p));
}
for (n, p) in self.to_k.named_parameters() {
o.push((format!("to_k.{n}"), p));
}
for (n, p) in self.to_v.named_parameters() {
o.push((format!("to_v.{n}"), p));
}
for (n, p) in self.to_out_0.named_parameters() {
o.push((format!("to_out.0.{n}"), p));
}
o
}
fn train(&mut self) {
self.training = true;
}
fn eval(&mut self) {
self.training = false;
}
fn is_training(&self) -> bool {
self.training
}
fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
let extract = |prefix: &str| -> StateDict<T> {
let p = format!("{prefix}.");
state
.iter()
.filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
.collect()
};
if strict {
for k in state.keys() {
let ok = k.starts_with("to_q.")
|| k.starts_with("to_k.")
|| k.starts_with("to_v.")
|| k.starts_with("to_out.0.");
if !ok {
return Err(FerrotorchError::InvalidArgument {
message: format!("unexpected key in Attention state_dict: \"{k}\""),
});
}
}
}
self.to_q.load_state_dict(&extract("to_q"), strict)?;
self.to_k.load_state_dict(&extract("to_k"), strict)?;
self.to_v.load_state_dict(&extract("to_v"), strict)?;
self.to_out_0
.load_state_dict(&extract("to_out.0"), strict)?;
Ok(())
}
}
/// Gated feed-forward block: a widening projection whose output is split in
/// two halves, one of which gates the other through GELU, followed by a
/// projection back to the input width.
#[derive(Debug)]
pub struct FeedForward<T: Float> {
/// Input projection (`dim -> 2 * dim_ff`); parameters are named `net.0.proj.*`.
pub net_0_proj: Linear<T>,
/// Output projection (`dim_ff -> dim`); parameters are named `net.2.*`.
pub net_2: Linear<T>,
/// Activation applied to the gate half of the projection.
activation: GELU,
/// Hidden width, computed as `dim * mult` at construction.
dim_ff: usize,
/// Mode flag toggled by `train()` / `eval()`; starts as `false`.
training: bool,
}
impl<T: Float> FeedForward<T> {
    /// Builds a gated feed-forward block for inputs of width `dim`.
    ///
    /// The hidden width is `dim * mult`. The input projection emits twice
    /// that width (value half plus gate half) and the output projection maps
    /// the hidden width back to `dim`. Both projections carry a bias.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `Linear::new`.
    pub fn new(dim: usize, mult: usize) -> FerrotorchResult<Self> {
        let hidden = dim * mult;
        // Fields are initialised in source order, so the input projection is
        // still constructed before the output projection.
        Ok(Self {
            net_0_proj: Linear::<T>::new(dim, 2 * hidden, true)?,
            net_2: Linear::<T>::new(hidden, dim, true)?,
            activation: GELU::new(),
            dim_ff: hidden,
            training: false,
        })
    }
}
impl<T: Float> Module<T> for FeedForward<T> {
    /// Applies the gated feed-forward transform.
    ///
    /// The input projection output is split into two chunks along its last
    /// axis; the first chunk is multiplied element-wise by the GELU of the
    /// second, and the product goes through the output projection.
    ///
    /// # Errors
    ///
    /// Returns [`FerrotorchError::ShapeMismatch`] if the split does not yield
    /// exactly two parts; tensor-operation errors are propagated.
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let proj = self.net_0_proj.forward(input)?;
        let last = proj.ndim() - 1;
        let parts = proj.chunk(2, last)?;
        if parts.len() != 2 {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "FeedForward: chunk(2) returned {} parts (expected 2)",
                    parts.len()
                ),
            });
        }
        let x = parts[0].contiguous()?;
        let gate = parts[1].contiguous()?;
        let gated = self.activation.forward(&gate)?;
        let activated = ferrotorch_core::grad_fns::arithmetic::mul(&x, &gated)?;
        self.net_2.forward(&activated)
    }

    /// Parameters of the input projection followed by the output projection.
    fn parameters(&self) -> Vec<&Parameter<T>> {
        let mut out = Vec::new();
        out.extend(self.net_0_proj.parameters());
        out.extend(self.net_2.parameters());
        out
    }

    /// Mutable counterpart of [`Self::parameters`], same order.
    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        let mut out = Vec::new();
        out.extend(self.net_0_proj.parameters_mut());
        out.extend(self.net_2.parameters_mut());
        out
    }

    /// Parameters keyed as `net.0.proj.*` and `net.2.*`.
    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        let mut out = Vec::new();
        out.extend(
            self.net_0_proj
                .named_parameters()
                .into_iter()
                .map(|(n, p)| (format!("net.0.proj.{n}"), p)),
        );
        out.extend(
            self.net_2
                .named_parameters()
                .into_iter()
                .map(|(n, p)| (format!("net.2.{n}"), p)),
        );
        out
    }

    /// Switches this block and both projections to training mode.
    fn train(&mut self) {
        self.training = true;
        // Propagate so the children's `is_training()` matches the parent.
        // The activation is not listed in `parameters()` and is left alone.
        self.net_0_proj.train();
        self.net_2.train();
    }

    /// Switches this block and both projections to evaluation mode.
    fn eval(&mut self) {
        self.training = false;
        self.net_0_proj.eval();
        self.net_2.eval();
    }

    /// Whether the block is currently flagged as training.
    fn is_training(&self) -> bool {
        self.training
    }

    /// Loads both projections from `state`, routing keys by the
    /// `net.0.proj.` and `net.2.` prefixes (stripped before delegation).
    ///
    /// # Errors
    ///
    /// With `strict`, returns [`FerrotorchError::InvalidArgument`] for any
    /// key outside those prefixes. Child errors are propagated.
    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
        // Sub-dictionary for one child: keep keys under `prefix.` and strip it.
        let extract = |prefix: &str| -> StateDict<T> {
            let p = format!("{prefix}.");
            state
                .iter()
                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
                .collect()
        };
        if strict {
            for k in state.keys() {
                let ok = k.starts_with("net.0.proj.") || k.starts_with("net.2.");
                if !ok {
                    return Err(FerrotorchError::InvalidArgument {
                        message: format!("unexpected key in FeedForward state_dict: \"{k}\""),
                    });
                }
            }
        }
        self.net_0_proj
            .load_state_dict(&extract("net.0.proj"), strict)?;
        self.net_2.load_state_dict(&extract("net.2"), strict)?;
        // `dim_ff` is otherwise never read; this keeps the field from being
        // reported as dead code.
        let _ = self.dim_ff;
        Ok(())
    }
}
/// Transformer block made of three pre-normalised residual stages:
/// self-attention, cross-attention over encoder states, and a feed-forward.
#[derive(Debug)]
pub struct BasicTransformerBlock<T: Float> {
/// Normalisation applied before self-attention.
pub norm1: LayerNorm<T>,
/// Self-attention (built without a cross-attention width).
pub attn1: Attention<T>,
/// Normalisation applied before cross-attention.
pub norm2: LayerNorm<T>,
/// Cross-attention over the encoder hidden states.
pub attn2: Attention<T>,
/// Normalisation applied before the feed-forward.
pub norm3: LayerNorm<T>,
/// Feed-forward stage.
pub ff: FeedForward<T>,
/// Expected size of the last axis of the block input.
dim: usize,
/// Mode flag toggled by `train()` / `eval()`; starts as `false`.
training: bool,
}
impl<T: Float> BasicTransformerBlock<T> {
    /// Assembles a block operating on `dim`-wide token features.
    ///
    /// * `heads`, `dim_head` - forwarded to both attention layers.
    /// * `cross_attention_dim` - feature width of the encoder states consumed
    ///   by the second attention layer.
    ///
    /// Both attention layers are built without q/k/v bias, and the
    /// feed-forward uses a width multiplier of 4.
    ///
    /// # Errors
    ///
    /// Propagates any error raised while constructing a sub-layer.
    pub fn new(
        dim: usize,
        heads: usize,
        dim_head: usize,
        cross_attention_dim: usize,
    ) -> FerrotorchResult<Self> {
        // All three norms share one configuration.
        let layer_norm = || LayerNorm::<T>::new(vec![dim], 1e-5, true);
        // Field initialisers run in source order, which matches the original
        // construction sequence: norm1, attn1, norm2, attn2, norm3, ff.
        Ok(Self {
            norm1: layer_norm()?,
            attn1: Attention::<T>::new(dim, None, heads, dim_head, false)?,
            norm2: layer_norm()?,
            attn2: Attention::<T>::new(dim, Some(cross_attention_dim), heads, dim_head, false)?,
            norm3: layer_norm()?,
            ff: FeedForward::<T>::new(dim, 4)?,
            dim,
            training: false,
        })
    }

    /// Runs the three residual stages on `x` (`[B, N, dim]`), attending over
    /// `encoder_hidden_states` in the second stage.
    ///
    /// # Errors
    ///
    /// Returns [`FerrotorchError::ShapeMismatch`] when `x` is not rank 3 or
    /// its last axis differs from `dim`; sub-layer errors are propagated.
    pub fn forward_xattn(
        &self,
        x: &Tensor<T>,
        encoder_hidden_states: &Tensor<T>,
    ) -> FerrotorchResult<Tensor<T>> {
        use ferrotorch_core::grad_fns::arithmetic::add;

        let well_formed = x.ndim() == 3 && x.shape()[2] == self.dim;
        if !well_formed {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "BasicTransformerBlock::forward: expected x [B, N, {}], got {:?}",
                    self.dim,
                    x.shape()
                ),
            });
        }

        // Stage 1: self-attention with residual.
        let self_attn = self.attn1.forward_xattn(&self.norm1.forward(x)?, None)?;
        let after_self = add(&self_attn, x)?;

        // Stage 2: cross-attention with residual.
        let cross_attn = self
            .attn2
            .forward_xattn(&self.norm2.forward(&after_self)?, Some(encoder_hidden_states))?;
        let after_cross = add(&cross_attn, &after_self)?;

        // Stage 3: feed-forward with residual.
        let ff_out = self.ff.forward(&self.norm3.forward(&after_cross)?)?;
        add(&ff_out, &after_cross)
    }
}
impl<T: Float> Module<T> for BasicTransformerBlock<T> {
    /// Always fails: the block needs encoder states, which this
    /// single-input entry point cannot supply. Use `forward_xattn`.
    fn forward(&self, _input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        Err(FerrotorchError::InvalidArgument {
            message: "BasicTransformerBlock::forward: cross-attn requires \
                      encoder_hidden_states — call forward_xattn instead"
                .into(),
        })
    }

    /// Parameters in the order norm1, attn1, norm2, attn2, norm3, ff.
    fn parameters(&self) -> Vec<&Parameter<T>> {
        let mut out = Vec::new();
        out.extend(self.norm1.parameters());
        out.extend(self.attn1.parameters());
        out.extend(self.norm2.parameters());
        out.extend(self.attn2.parameters());
        out.extend(self.norm3.parameters());
        out.extend(self.ff.parameters());
        out
    }

    /// Mutable counterpart of [`Self::parameters`], same order.
    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        let mut out = Vec::new();
        out.extend(self.norm1.parameters_mut());
        out.extend(self.attn1.parameters_mut());
        out.extend(self.norm2.parameters_mut());
        out.extend(self.attn2.parameters_mut());
        out.extend(self.norm3.parameters_mut());
        out.extend(self.ff.parameters_mut());
        out
    }

    /// Parameters keyed by sub-layer: `norm1.*`, `attn1.*`, `norm2.*`,
    /// `attn2.*`, `norm3.*`, `ff.*`.
    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        let mut out = Vec::new();
        for (n, p) in self.norm1.named_parameters() {
            out.push((format!("norm1.{n}"), p));
        }
        for (n, p) in self.attn1.named_parameters() {
            out.push((format!("attn1.{n}"), p));
        }
        for (n, p) in self.norm2.named_parameters() {
            out.push((format!("norm2.{n}"), p));
        }
        for (n, p) in self.attn2.named_parameters() {
            out.push((format!("attn2.{n}"), p));
        }
        for (n, p) in self.norm3.named_parameters() {
            out.push((format!("norm3.{n}"), p));
        }
        for (n, p) in self.ff.named_parameters() {
            out.push((format!("ff.{n}"), p));
        }
        out
    }

    /// Switches this block and every sub-layer to training mode.
    fn train(&mut self) {
        self.training = true;
        // The sub-layers are public and track their own flag; propagate so
        // they never disagree with the parent.
        self.norm1.train();
        self.attn1.train();
        self.norm2.train();
        self.attn2.train();
        self.norm3.train();
        self.ff.train();
    }

    /// Switches this block and every sub-layer to evaluation mode.
    fn eval(&mut self) {
        self.training = false;
        self.norm1.eval();
        self.attn1.eval();
        self.norm2.eval();
        self.attn2.eval();
        self.norm3.eval();
        self.ff.eval();
    }

    /// Whether the block is currently flagged as training.
    fn is_training(&self) -> bool {
        self.training
    }

    /// Loads every sub-layer from `state`, routing keys by sub-layer prefix
    /// (stripped before delegation).
    ///
    /// # Errors
    ///
    /// With `strict`, returns [`FerrotorchError::InvalidArgument`] for any
    /// key that belongs to no sub-layer. Child errors are propagated.
    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
        // Sub-dictionary for one child: keep keys under `prefix.` and strip it.
        let extract = |prefix: &str| -> StateDict<T> {
            let p = format!("{prefix}.");
            state
                .iter()
                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
                .collect()
        };
        if strict {
            for k in state.keys() {
                let ok = k.starts_with("norm1.")
                    || k.starts_with("attn1.")
                    || k.starts_with("norm2.")
                    || k.starts_with("attn2.")
                    || k.starts_with("norm3.")
                    || k.starts_with("ff.");
                if !ok {
                    return Err(FerrotorchError::InvalidArgument {
                        message: format!(
                            "unexpected key in BasicTransformerBlock state_dict: \"{k}\""
                        ),
                    });
                }
            }
        }
        self.norm1.load_state_dict(&extract("norm1"), strict)?;
        self.attn1.load_state_dict(&extract("attn1"), strict)?;
        self.norm2.load_state_dict(&extract("norm2"), strict)?;
        self.attn2.load_state_dict(&extract("attn2"), strict)?;
        self.norm3.load_state_dict(&extract("norm3"), strict)?;
        self.ff.load_state_dict(&extract("ff"), strict)?;
        Ok(())
    }
}
/// Spatial transformer: normalises and projects a `[B, C, H, W]` feature map,
/// runs it as a token sequence through a stack of transformer blocks, then
/// projects back and adds the input as a residual.
#[derive(Debug)]
pub struct Transformer2DModel<T: Float> {
/// Group normalisation applied to the input feature map.
pub norm: GroupNorm<T>,
/// 1x1 convolution mapping `in_channels -> inner_dim`.
pub proj_in: Conv2d<T>,
/// Transformer blocks applied in order; parameters are named
/// `transformer_blocks.{index}.*`.
pub transformer_blocks: Vec<BasicTransformerBlock<T>>,
/// 1x1 convolution mapping `inner_dim -> in_channels`.
pub proj_out: Conv2d<T>,
/// Expected channel count of the input (`in_channels`).
channels: usize,
/// Token feature width, computed as `heads * dim_head`.
inner_dim: usize,
/// Mode flag toggled by `train()` / `eval()`; starts as `false`.
training: bool,
}
impl<T: Float> Transformer2DModel<T> {
    /// Builds the model.
    ///
    /// * `in_channels` - channel count of the input and output feature map.
    /// * `heads`, `dim_head` - attention geometry; the token width is
    ///   `heads * dim_head`.
    /// * `num_layers` - number of transformer blocks.
    /// * `cross_attention_dim` - feature width of the encoder states.
    /// * `norm_num_groups` - group count for the input normalisation.
    ///
    /// # Errors
    ///
    /// Propagates any error raised while constructing a sub-layer.
    pub fn new(
        in_channels: usize,
        heads: usize,
        dim_head: usize,
        num_layers: usize,
        cross_attention_dim: usize,
        norm_num_groups: usize,
    ) -> FerrotorchResult<Self> {
        let inner_dim = heads * dim_head;
        let norm = GroupNorm::<T>::new(norm_num_groups, in_channels, 1e-6, true)?;
        let proj_in = Conv2d::<T>::new(in_channels, inner_dim, (1, 1), (1, 1), (0, 0), true)?;
        let proj_out = Conv2d::<T>::new(inner_dim, in_channels, (1, 1), (1, 1), (0, 0), true)?;
        let mut transformer_blocks = Vec::with_capacity(num_layers);
        for _ in 0..num_layers {
            transformer_blocks.push(BasicTransformerBlock::<T>::new(
                inner_dim,
                heads,
                dim_head,
                cross_attention_dim,
            )?);
        }
        Ok(Self {
            norm,
            proj_in,
            transformer_blocks,
            proj_out,
            channels: in_channels,
            inner_dim,
            training: false,
        })
    }

    /// Applies the model to `x` (`[B, channels, H, W]`), attending over
    /// `encoder_hidden_states` in every block.
    ///
    /// The result is the output projection plus `x` itself (residual).
    ///
    /// # Errors
    ///
    /// Returns [`FerrotorchError::ShapeMismatch`] when `x` is not rank 4 or
    /// its channel axis differs from `channels`; sub-layer errors are
    /// propagated.
    pub fn forward_xattn(
        &self,
        x: &Tensor<T>,
        encoder_hidden_states: &Tensor<T>,
    ) -> FerrotorchResult<Tensor<T>> {
        if x.ndim() != 4 || x.shape()[1] != self.channels {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "Transformer2DModel::forward: expected [B, {}, H, W], got {:?}",
                    self.channels,
                    x.shape()
                ),
            });
        }
        let b = x.shape()[0];
        let h = x.shape()[2];
        let w = x.shape()[3];
        let hw = h * w;

        let hidden = self.norm.forward(x)?;
        let hidden = self.proj_in.forward(&hidden)?;

        // [B, inner_dim, H, W] -> [B, H*W, inner_dim] token sequence.
        let mut hidden_seq = hidden
            .reshape_t(&[b as isize, self.inner_dim as isize, hw as isize])?
            .transpose(1, 2)?
            .contiguous()?;
        for block in &self.transformer_blocks {
            hidden_seq = block.forward_xattn(&hidden_seq, encoder_hidden_states)?;
        }

        // [B, H*W, inner_dim] -> [B, inner_dim, H, W]. The transposed view is
        // made contiguous *before* reshaping, as everywhere else in this
        // file, so the result does not depend on how `reshape_t` treats a
        // strided view.
        let hidden_back = hidden_seq
            .transpose(1, 2)?
            .contiguous()?
            .reshape_t(&[b as isize, self.inner_dim as isize, h as isize, w as isize])?;
        let out = self.proj_out.forward(&hidden_back)?;

        // Residual connection with the untouched input.
        ferrotorch_core::grad_fns::arithmetic::add(&out, x)
    }
}
impl<T: Float> Module<T> for Transformer2DModel<T> {
    /// Always fails: the model needs encoder states, which this
    /// single-input entry point cannot supply. Use `forward_xattn`.
    fn forward(&self, _input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        Err(FerrotorchError::InvalidArgument {
            message: "Transformer2DModel::forward: cross-attn requires \
                      encoder_hidden_states — call forward_xattn instead"
                .into(),
        })
    }

    /// Parameters in the order norm, proj_in, each block, proj_out.
    fn parameters(&self) -> Vec<&Parameter<T>> {
        let mut out = Vec::new();
        out.extend(self.norm.parameters());
        out.extend(self.proj_in.parameters());
        for block in &self.transformer_blocks {
            out.extend(block.parameters());
        }
        out.extend(self.proj_out.parameters());
        out
    }

    /// Mutable counterpart of [`Self::parameters`], same order.
    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        let mut out = Vec::new();
        out.extend(self.norm.parameters_mut());
        out.extend(self.proj_in.parameters_mut());
        for block in &mut self.transformer_blocks {
            out.extend(block.parameters_mut());
        }
        out.extend(self.proj_out.parameters_mut());
        out
    }

    /// Parameters keyed as `norm.*`, `proj_in.*`,
    /// `transformer_blocks.{i}.*` and `proj_out.*`.
    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        let mut out = Vec::new();
        for (n, p) in self.norm.named_parameters() {
            out.push((format!("norm.{n}"), p));
        }
        for (n, p) in self.proj_in.named_parameters() {
            out.push((format!("proj_in.{n}"), p));
        }
        for (i, block) in self.transformer_blocks.iter().enumerate() {
            for (n, p) in block.named_parameters() {
                out.push((format!("transformer_blocks.{i}.{n}"), p));
            }
        }
        for (n, p) in self.proj_out.named_parameters() {
            out.push((format!("proj_out.{n}"), p));
        }
        out
    }

    /// Switches the model and every sub-layer to training mode.
    fn train(&mut self) {
        self.training = true;
        // Blocks and layers are public and track their own flag; propagate
        // so they never disagree with the parent.
        self.norm.train();
        self.proj_in.train();
        for block in &mut self.transformer_blocks {
            block.train();
        }
        self.proj_out.train();
    }

    /// Switches the model and every sub-layer to evaluation mode.
    fn eval(&mut self) {
        self.training = false;
        self.norm.eval();
        self.proj_in.eval();
        for block in &mut self.transformer_blocks {
            block.eval();
        }
        self.proj_out.eval();
    }

    /// Whether the model is currently flagged as training.
    fn is_training(&self) -> bool {
        self.training
    }

    /// Loads every sub-layer from `state`, routing keys by prefix (stripped
    /// before delegation). Block `i` receives the keys under
    /// `transformer_blocks.{i}.`.
    ///
    /// # Errors
    ///
    /// With `strict`, returns [`FerrotorchError::InvalidArgument`] for any
    /// key that belongs to no sub-layer. That includes
    /// `transformer_blocks.*` keys whose index is malformed, non-canonical
    /// (e.g. `00`, `+0`) or not smaller than the number of blocks, because
    /// no block would ever consume them. Child errors are propagated.
    fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
        // Sub-dictionary for one child: keep keys under `prefix.` and strip it.
        let extract = |prefix: &str| -> StateDict<T> {
            let p = format!("{prefix}.");
            state
                .iter()
                .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
                .collect()
        };
        if strict {
            let n_blocks = self.transformer_blocks.len();
            // `rest` is the key with `transformer_blocks.` removed. It is
            // routable only if it looks like `{i}.<something>` with `i`
            // written exactly the way `extract` formats it and `i` in range.
            let targets_existing_block = |rest: &str| -> bool {
                rest.split_once('.')
                    .and_then(|(idx, _)| {
                        idx.parse::<usize>()
                            .ok()
                            .filter(|i| i.to_string() == idx)
                    })
                    .map_or(false, |i| i < n_blocks)
            };
            for k in state.keys() {
                let ok = match k.strip_prefix("transformer_blocks.") {
                    Some(rest) => targets_existing_block(rest),
                    None => {
                        k.starts_with("norm.")
                            || k.starts_with("proj_in.")
                            || k.starts_with("proj_out.")
                    }
                };
                if !ok {
                    return Err(FerrotorchError::InvalidArgument {
                        message: format!(
                            "unexpected key in Transformer2DModel state_dict: \"{k}\""
                        ),
                    });
                }
            }
        }
        self.norm.load_state_dict(&extract("norm"), strict)?;
        self.proj_in.load_state_dict(&extract("proj_in"), strict)?;
        for (i, block) in self.transformer_blocks.iter_mut().enumerate() {
            block.load_state_dict(&extract(&format!("transformer_blocks.{i}")), strict)?;
        }
        self.proj_out
            .load_state_dict(&extract("proj_out"), strict)?;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::TensorStorage;

    /// Builds an `f32` CPU tensor of the given shape with every element `0.01`.
    fn filled(shape: &[usize]) -> Tensor<f32> {
        let numel: usize = shape.iter().product();
        Tensor::from_storage(
            TensorStorage::cpu(vec![0.01f32; numel]),
            shape.to_vec(),
            false,
        )
        .unwrap()
    }

    /// Collects just the names reported by `named_parameters`.
    fn names_of<M: Module<f32>>(module: &M) -> Vec<String> {
        module
            .named_parameters()
            .into_iter()
            .map(|(n, _)| n)
            .collect()
    }

    /// Asserts that each expected key appears among `names`.
    fn assert_has_keys(names: &[String], expected: &[&str]) {
        for &k in expected {
            assert!(names.iter().any(|n| n == k), "missing {k} in {names:?}");
        }
    }

    #[test]
    fn attention_self_shape() {
        let attn = Attention::<f32>::new(16, None, 4, 4, false).unwrap();
        let x = filled(&[1, 5, 16]);
        let y = attn.forward_xattn(&x, None).unwrap();
        assert_eq!(y.shape(), &[1, 5, 16]);
    }

    #[test]
    fn attention_cross_shape() {
        let attn = Attention::<f32>::new(16, Some(24), 4, 4, false).unwrap();
        let x = filled(&[1, 5, 16]);
        let ehs = filled(&[1, 7, 24]);
        let y = attn.forward_xattn(&x, Some(&ehs)).unwrap();
        assert_eq!(y.shape(), &[1, 5, 16]);
    }

    #[test]
    fn feedforward_shape_and_keys() {
        let ff = FeedForward::<f32>::new(16, 2).unwrap();
        let x = filled(&[1, 5, 16]);
        let y = ff.forward(&x).unwrap();
        assert_eq!(y.shape(), &[1, 5, 16]);
        assert_has_keys(
            &names_of(&ff),
            &[
                "net.0.proj.weight",
                "net.0.proj.bias",
                "net.2.weight",
                "net.2.bias",
            ],
        );
    }

    #[test]
    fn basic_transformer_block_shape() {
        let blk = BasicTransformerBlock::<f32>::new(16, 4, 4, 24).unwrap();
        let x = filled(&[1, 5, 16]);
        let ehs = filled(&[1, 7, 24]);
        let y = blk.forward_xattn(&x, &ehs).unwrap();
        assert_eq!(y.shape(), &[1, 5, 16]);
    }

    #[test]
    fn transformer_2d_shape() {
        let model = Transformer2DModel::<f32>::new(16, 4, 4, 1, 24, 4).unwrap();
        let x = filled(&[1, 16, 3, 3]);
        let ehs = filled(&[1, 5, 24]);
        let y = model.forward_xattn(&x, &ehs).unwrap();
        assert_eq!(y.shape(), &[1, 16, 3, 3]);
    }

    #[test]
    fn transformer_2d_named_parameters() {
        let model = Transformer2DModel::<f32>::new(16, 4, 4, 1, 24, 4).unwrap();
        assert_has_keys(
            &names_of(&model),
            &[
                "norm.weight",
                "proj_in.weight",
                "proj_in.bias",
                "transformer_blocks.0.norm1.weight",
                "transformer_blocks.0.attn1.to_q.weight",
                "transformer_blocks.0.attn2.to_k.weight",
                "transformer_blocks.0.ff.net.0.proj.weight",
                "transformer_blocks.0.ff.net.2.weight",
                "proj_out.weight",
            ],
        );
    }
}