use crate as burn;
use crate::nn::cache::TensorCache;
use crate::{
    config::Config,
    module::Module,
    nn,
    tensor::{activation, backend::Backend, Bool, Tensor},
};
use libm::sqrtf;
/// Configuration used to build a [`MultiHeadAttention`] module.
#[derive(Config)]
pub struct MultiHeadAttentionConfig {
    /// Input and output size of each of the four linear layers
    /// (query, key, value and output projections).
    d_model: usize,
    /// Number of attention heads. The per-head size is computed as
    /// `d_model / n_heads` when the module is initialized.
    n_heads: usize,
    /// Dropout setting forwarded to `nn::DropoutConfig::new`. Default: 0.1
    #[config(default = 0.1)]
    dropout: f64,
    /// Value written into the attention scores at masked positions before the
    /// softmax is applied. Default: -1.0e4
    #[config(default = -1.0e4)]
    min_float: f64,
}
/// Multi-head attention module.
///
/// Created through [`MultiHeadAttentionConfig::init`] or
/// [`MultiHeadAttentionConfig::init_with`].
#[derive(Module, Debug)]
pub struct MultiHeadAttention<B: Backend> {
    /// Linear projection applied to the query input.
    query: nn::Linear<B>,
    /// Linear projection applied to the key input.
    key: nn::Linear<B>,
    /// Linear projection applied to the value input.
    value: nn::Linear<B>,
    /// Linear projection applied to the merged heads to produce the context.
    output: nn::Linear<B>,
    /// Dropout applied to the scaled attention scores.
    dropout: nn::Dropout,
    /// GELU activation; not referenced by any method visible in this file.
    activation: nn::GELU,
    /// Number of attention heads.
    n_heads: usize,
    /// Size of each head (`d_model / n_heads`).
    d_k: usize,
    /// Fill value used for masked attention scores.
    min_float: f64,
}
/// Input of [`MultiHeadAttention::forward`] and
/// [`MultiHeadAttention::forward_cache`].
///
/// Build it with [`MhaInput::self_attn`] or [`MhaInput::new`], then optionally
/// attach masks with [`MhaInput::mask_pad`] and [`MhaInput::mask_attn`].
#[derive(Debug, Clone)]
pub struct MhaInput<B: Backend> {
    /// Query tensor, read as `[batch_size, seq_length_1, d_model]`.
    query: Tensor<B, 3>,
    /// Key tensor; the tests in this file use `[batch_size, seq_length_2, d_model]`.
    key: Tensor<B, 3>,
    /// Value tensor; the tests in this file use `[batch_size, seq_length_2, d_model]`.
    value: Tensor<B, 3>,
    /// Optional padding mask read as `[batch_size, seq_length]`; it is reshaped
    /// to `[batch_size, 1, 1, seq_length]` before being applied to the scores.
    mask_pad: Option<Tensor<B, 2, Bool>>,
    /// Optional attention mask read as `[batch_size, seq_length_1, seq_length_2]`;
    /// it is reshaped to `[batch_size, 1, seq_length_1, seq_length_2]` before
    /// being applied to the scores.
    mask_attn: Option<Tensor<B, 3, Bool>>,
}
impl MultiHeadAttentionConfig {
    /// Initialize a new [`MultiHeadAttention`] module with freshly initialized
    /// linear layers.
    ///
    /// # Panics
    ///
    /// Panics when `n_heads` is zero or does not evenly divide `d_model`.
    pub fn init<B: Backend>(&self) -> MultiHeadAttention<B> {
        // Validate before any layer is created so a bad config fails fast.
        let d_k = self.head_dim();
        let linear = |config: &Self| nn::LinearConfig::new(config.d_model, config.d_model).init();

        MultiHeadAttention {
            query: linear(self),
            key: linear(self),
            value: linear(self),
            output: linear(self),
            dropout: nn::DropoutConfig::new(self.dropout).init(),
            activation: nn::GELU::new(),
            n_heads: self.n_heads,
            d_k,
            min_float: self.min_float,
        }
    }

    /// Initialize a new [`MultiHeadAttention`] module whose four linear layers
    /// are built from the matching entries of `record`.
    ///
    /// # Panics
    ///
    /// Panics when `n_heads` is zero or does not evenly divide `d_model`.
    pub fn init_with<B: Backend>(
        &self,
        record: MultiHeadAttentionRecord<B>,
    ) -> MultiHeadAttention<B> {
        // Validate before any layer is created so a bad config fails fast.
        let d_k = self.head_dim();
        let linear = |config: &Self, record| {
            nn::LinearConfig::new(config.d_model, config.d_model).init_with(record)
        };

        MultiHeadAttention {
            query: linear(self, record.query),
            key: linear(self, record.key),
            value: linear(self, record.value),
            output: linear(self, record.output),
            dropout: nn::DropoutConfig::new(self.dropout).init(),
            activation: nn::GELU::new(),
            n_heads: self.n_heads,
            d_k,
            min_float: self.min_float,
        }
    }

    /// Size of a single attention head, `d_model / n_heads`.
    ///
    /// The forward pass splits the `d_model` features of each projection into
    /// `n_heads * d_k` values. With a truncated integer division those two
    /// element counts differ, so the configuration is rejected here with an
    /// explicit message instead of letting the mismatch reach the reshape.
    fn head_dim(&self) -> usize {
        assert!(
            self.n_heads > 0,
            "MultiHeadAttention requires at least one head, but n_heads is 0"
        );
        assert!(
            self.d_model % self.n_heads == 0,
            "MultiHeadAttention requires d_model ({}) to be divisible by n_heads ({})",
            self.d_model,
            self.n_heads
        );

        self.d_model / self.n_heads
    }
}
impl<B: Backend> MhaInput<B> {
    /// Create a self-attention input: the same tensor is used as query, key
    /// and value. No mask is set.
    pub fn self_attn(tensor: Tensor<B, 3>) -> Self {
        Self::new(tensor.clone(), tensor.clone(), tensor)
    }

    /// Create an input from distinct query, key and value tensors.
    /// No mask is set.
    pub fn new(query: Tensor<B, 3>, key: Tensor<B, 3>, value: Tensor<B, 3>) -> Self {
        Self {
            mask_pad: None,
            mask_attn: None,
            query,
            key,
            value,
        }
    }

    /// Attach a padding mask, replacing any padding mask set previously.
    pub fn mask_pad(self, mask_pad: Tensor<B, 2, Bool>) -> Self {
        Self {
            mask_pad: Some(mask_pad),
            ..self
        }
    }

    /// Attach an attention mask, replacing any attention mask set previously.
    pub fn mask_attn(self, mask_attn: Tensor<B, 3, Bool>) -> Self {
        Self {
            mask_attn: Some(mask_attn),
            ..self
        }
    }
}
/// Output of [`MultiHeadAttention::forward`] and
/// [`MultiHeadAttention::forward_cache`].
#[derive(Debug, Clone)]
pub struct MhaOutput<B: Backend> {
    /// Attention weights (softmax over the last dimension of the scores).
    /// The tests in this file check the shape
    /// `[batch_size, n_heads, seq_length_1, seq_length_2]`.
    pub weights: Tensor<B, 4>,
    /// Context tensor produced by the output projection, with the same shape
    /// as the query input: `[batch_size, seq_length_1, d_model]`.
    pub context: Tensor<B, 3>,
}
impl<B: Backend> MultiHeadAttention<B> {
    pub fn forward(&self, input: MhaInput<B>) -> MhaOutput<B> {
        let [batch_size, seq_length_1, d_model] = input.query.dims();
        let query = self.attention_linear(input.query, &self.query);
        let key = self.attention_linear(input.key, &self.key);
        let value = self.attention_linear(input.value, &self.value);
        let attn_scores = self.attn_scores(query, key);
        let weights = self.attn_weights(attn_scores, input.mask_pad, input.mask_attn);
        let context = weights.clone().matmul(value);
        let context = context
            .swap_dims(1, 2)
            .reshape([batch_size, seq_length_1, d_model]);
        let context = self.output.forward(context);
        MhaOutput { weights, context }
    }
    pub fn forward_cache(&self, input: MhaInput<B>, cache: &mut MhaCache<B>) -> MhaOutput<B> {
        let [batch_size, seq_length_1, d_model] = input.query.dims();
        let query = cache
            .query
            .forward(input.query, |t| self.attention_linear(t, &self.query));
        let key = cache
            .key
            .forward(input.key, |t| self.attention_linear(t, &self.key));
        let value = cache
            .value
            .forward(input.value, |t| self.attention_linear(t, &self.value));
        let attn_scores = self.attn_scores(query, key);
        let weights = self.attn_weights(attn_scores, input.mask_pad, input.mask_attn);
        let context = weights.clone().matmul(value);
        let context = context
            .swap_dims(1, 2)
            .reshape([batch_size, seq_length_1, d_model]);
        let context = cache.output.forward(context, |t| self.output.forward(t));
        MhaOutput { weights, context }
    }
    fn attn_scores(&self, query: Tensor<B, 4>, key: Tensor<B, 4>) -> Tensor<B, 4> {
        let attn_scores = query
            .matmul(key.transpose())
            .div_scalar(sqrtf(self.d_k as f32));
        self.dropout.forward(attn_scores)
    }
    fn attn_weights(
        &self,
        mut attn_scores: Tensor<B, 4>,
        mask_pad: Option<Tensor<B, 2, Bool>>,
        mask_attn: Option<Tensor<B, 3, Bool>>,
    ) -> Tensor<B, 4> {
        if let Some(mask_pad) = mask_pad {
            let [batch_size, seq_length] = mask_pad.dims();
            attn_scores = attn_scores.mask_fill(
                mask_pad.reshape([batch_size, 1, 1, seq_length]),
                self.min_float,
            );
        }
        if let Some(mask_attn) = mask_attn {
            let [batch_size, seq_length_1, seq_length_2] = mask_attn.dims();
            attn_scores = attn_scores.mask_fill(
                mask_attn.reshape([batch_size, 1, seq_length_1, seq_length_2]),
                self.min_float,
            );
        }
        activation::softmax(attn_scores, 3)
    }
    fn attention_linear(&self, x: Tensor<B, 3>, linear: &nn::Linear<B>) -> Tensor<B, 4> {
        let [batch_size, seq_length, _d_model] = x.dims();
        linear
            .forward(x)
            .reshape([batch_size, seq_length, self.n_heads, self.d_k])
            .swap_dims(1, 2)
    }
}
/// Cache used by [`MultiHeadAttention::forward_cache`], holding one entry per
/// linear projection of the module.
pub struct MhaCache<B: Backend> {
    /// Cache for the projected query (4-dimensional, split into heads).
    query: MhaLinearCache<B, 4>,
    /// Cache for the projected key (4-dimensional, split into heads).
    key: MhaLinearCache<B, 4>,
    /// Cache for the projected value (4-dimensional, split into heads).
    value: MhaLinearCache<B, 4>,
    /// Cache for the result of the output projection (3-dimensional).
    output: MhaLinearCache<B, 3>,
}
/// Caching strategy for the result of one linear projection.
enum MhaLinearCache<B: Backend, const D: usize> {
    /// Dispatches to `TensorCache::forward_autoregressive`; the `usize` is the
    /// dimension argument handed to that call.
    Autoregressive(TensorCache<B, D>, usize),
    /// Dispatches to `TensorCache::forward_full`.
    Full(TensorCache<B, D>),
}
impl<B: Backend> MhaCache<B> {
    /// Create an empty cache in which all four projections use the
    /// autoregressive strategy.
    pub fn autoregressive() -> Self {
        // Sequence axis of the head-split projections
        // (`[batch_size, n_heads, seq_length, d_k]`).
        let seq_dim_heads = 2;
        // Sequence axis of the output projection
        // (`[batch_size, seq_length, d_model]`).
        let seq_dim_output = 1;

        Self {
            query: MhaLinearCache::Autoregressive(TensorCache::empty(), seq_dim_heads),
            key: MhaLinearCache::Autoregressive(TensorCache::empty(), seq_dim_heads),
            value: MhaLinearCache::Autoregressive(TensorCache::empty(), seq_dim_heads),
            output: MhaLinearCache::Autoregressive(TensorCache::empty(), seq_dim_output),
        }
    }

    /// Create an empty cache for cross attention: query and output use the
    /// autoregressive strategy, key and value use the full strategy.
    pub fn autoregressive_cross_attention() -> Self {
        Self {
            key: MhaLinearCache::Full(TensorCache::empty()),
            value: MhaLinearCache::Full(TensorCache::empty()),
            // Query and output are configured exactly as in `autoregressive`.
            ..Self::autoregressive()
        }
    }
}
impl<B: Backend, const D: usize> MhaLinearCache<B, D> {
    /// Run `func` on `tensor` through the underlying [`TensorCache`], using
    /// the strategy selected by the enum variant.
    pub fn forward<F>(&mut self, tensor: Tensor<B, 3>, func: F) -> Tensor<B, D>
    where
        F: Fn(Tensor<B, 3>) -> Tensor<B, D>,
    {
        match self {
            Self::Full(cache) => cache.forward_full(tensor, func),
            Self::Autoregressive(cache, dim) => {
                let dim = *dim;
                cache.forward_autoregressive(tensor, dim, func)
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{nn::attention::generate_autoregressive_mask, TestBackend};
    use alloc::vec::Vec;
    use burn::tensor::{Distribution, Shape};
    use burn_tensor::Int;

    /// Self-attention must return a context shaped like the input and weights
    /// shaped `[batch_size, n_heads, seq_length, seq_length]`.
    #[test]
    fn test_self_attention_shapes() {
        let (batch_size, seq_length, d_model, n_heads) = (7, 13, 32, 4);
        let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::<TestBackend>();
        let tensor = Tensor::<TestBackend, 3>::random(
            [batch_size, seq_length, d_model],
            Distribution::Standard,
        );

        let output = mha.forward(MhaInput::self_attn(tensor));

        let expected_context = Shape::new([batch_size, seq_length, d_model]);
        let expected_weights = Shape::new([batch_size, n_heads, seq_length, seq_length]);
        assert_eq!(
            output.context.shape(),
            expected_context,
            "Context should have the correct shape",
        );
        assert_eq!(
            output.weights.shape(),
            expected_weights,
            "Weights should have the correct shape",
        );
    }

    /// With a query of length `seq_length_1` and key/value of length
    /// `seq_length_2`, the context follows the query and the weights are
    /// `[batch_size, n_heads, seq_length_1, seq_length_2]`.
    #[test]
    fn test_generic_mha_shapes() {
        let (batch_size, seq_length_1, seq_length_2, d_model, n_heads) = (7, 13, 15, 32, 4);
        let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::<TestBackend>();

        let query = Tensor::<TestBackend, 3>::random(
            [batch_size, seq_length_1, d_model],
            Distribution::Standard,
        );
        let key = Tensor::<TestBackend, 3>::random(
            [batch_size, seq_length_2, d_model],
            Distribution::Standard,
        );
        let value = Tensor::<TestBackend, 3>::random(
            [batch_size, seq_length_2, d_model],
            Distribution::Standard,
        );

        let output = mha.forward(MhaInput::new(query, key, value));

        let expected_context = Shape::new([batch_size, seq_length_1, d_model]);
        let expected_weights = Shape::new([batch_size, n_heads, seq_length_1, seq_length_2]);
        assert_eq!(
            output.context.shape(),
            expected_context,
            "Context should have the correct shape",
        );
        assert_eq!(
            output.weights.shape(),
            expected_weights,
            "Weights should have the correct shape",
        );
    }

    /// Changing the values at padded positions must not change the context at
    /// the non-padded positions when the padding mask is supplied.
    #[test]
    fn test_self_attention_mask_pad() {
        let (batch_size, seq_length, d_model, n_heads, num_padded) = (3, 6, 32, 2, 2);
        let first_padded = seq_length - num_padded;
        let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::<TestBackend>();

        // The mask is true on the last `num_padded` positions of every sequence.
        let mask_pad = Tensor::<TestBackend, 2, Int>::zeros([batch_size, seq_length])
            .index_assign(
                [0..batch_size, first_padded..seq_length],
                Tensor::ones([batch_size, num_padded]),
            )
            .equal_elem(1);

        let tensor_1 = Tensor::<TestBackend, 3>::random(
            [batch_size, seq_length, d_model],
            Distribution::Standard,
        );
        // Same as `tensor_1`, except that the padded positions get new random values.
        let tensor_2 = tensor_1.clone().index_assign(
            [0..batch_size, first_padded..seq_length, 0..d_model],
            Tensor::random([batch_size, num_padded, d_model], Distribution::Standard),
        );

        let output_1 = mha.forward(MhaInput::self_attn(tensor_1).mask_pad(mask_pad.clone()));
        let output_2 = mha.forward(MhaInput::self_attn(tensor_2).mask_pad(mask_pad));

        // Only the non-padded part of each context is compared.
        let non_padded = |context: Tensor<TestBackend, 3>| {
            context
                .index([0..batch_size, 0..first_padded, 0..d_model])
                .into_data()
        };
        non_padded(output_1.context).assert_approx_eq(&non_padded(output_2.context), 3);
    }

    /// A single pass with the autoregressive mask must give the same context
    /// as decoding prefix by prefix with an autoregressive cache.
    #[test]
    fn test_autoregressive_mask_should_have_same_output_as_autoregressive_decoding() {
        let (batch_size, seq_length, d_model, n_heads) = (3, 4, 12, 2);
        let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::<TestBackend>();
        let tensor = Tensor::<TestBackend, 3>::random(
            [batch_size, seq_length, d_model],
            Distribution::Standard,
        );

        // Reference: one pass over the full sequence with the mask.
        let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &tensor.device());
        let output_1 = mha.forward(MhaInput::self_attn(tensor.clone()).mask_attn(mask_attn));

        // Cached decoding: feed growing prefixes and keep only the newest position.
        let mut cache = MhaCache::autoregressive();
        let steps: Vec<_> = (1..=seq_length)
            .map(|i| {
                let prefix = tensor.clone().index([0..batch_size, 0..i, 0..d_model]);
                mha.forward_cache(MhaInput::self_attn(prefix), &mut cache)
                    .context
                    .index([0..batch_size, i - 1..i, 0..d_model])
            })
            .collect();
        let output_2 = Tensor::cat(steps, 1);

        output_1
            .context
            .into_data()
            .assert_approx_eq(&output_2.into_data(), 3);
    }
}