flodl 0.2.2

floDl — a flow-graph deep learning framework built on libtorch
Documentation
use crate::autograd::Variable;
use crate::tensor::{Device, Result, Tensor};

use super::init;
use super::parameter::Parameter;
use super::Module;

/// Multi-head attention mechanism.
///
/// Implements `MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W^O`
/// where each `head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V)`.
///
/// Supports optional causal masking and key-value attention masks.
///
/// ```ignore
/// let mha = MultiheadAttention::on_device(512, 8, device)?;
/// // Self-attention: query = key = value
/// let y = mha.forward(&x)?;
/// // Cross-attention or masked: use forward_ext
/// let y = mha.forward_ext(&query, &key, &value, Some(&mask))?;
/// ```
pub struct MultiheadAttention {
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    out_proj: Linear,
    num_heads: i64,
    head_dim: i64,
    scale: f64,
}

struct Linear {
    weight: Parameter,
    bias: Parameter,
}

impl Linear {
    fn on_device(in_features: i64, out_features: i64, device: Device) -> Result<Self> {
        let w = init::xavier_uniform(
            &[out_features, in_features], in_features, out_features, device,
        )?;
        let b = Tensor::zeros(
            &[out_features],
            crate::tensor::TensorOptions { dtype: crate::tensor::DType::Float32, device },
        )?;
        Ok(Linear {
            weight: Parameter::new(w, "weight"),
            bias: Parameter::new(b, "bias"),
        })
    }

    fn forward(&self, input: &Variable) -> Result<Variable> {
        crate::autograd::linear(
            input,
            &self.weight.variable,
            Some(&self.bias.variable),
        )
    }

    fn parameters(&self, prefix: &str) -> Vec<Parameter> {
        vec![
            Parameter {
                variable: self.weight.variable.clone(),
                name: format!("{prefix}.weight"),
            },
            Parameter {
                variable: self.bias.variable.clone(),
                name: format!("{prefix}.bias"),
            },
        ]
    }
}

impl MultiheadAttention {
    /// Create a multi-head attention module on CPU.
    pub fn new(embed_dim: i64, num_heads: i64) -> Result<Self> {
        Self::on_device(embed_dim, num_heads, Device::CPU)
    }

    /// Create a multi-head attention module on a specific device.
    pub fn on_device(embed_dim: i64, num_heads: i64, device: Device) -> Result<Self> {
        assert!(
            embed_dim % num_heads == 0,
            "embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
        );
        let head_dim = embed_dim / num_heads;

        Ok(MultiheadAttention {
            q_proj: Linear::on_device(embed_dim, embed_dim, device)?,
            k_proj: Linear::on_device(embed_dim, embed_dim, device)?,
            v_proj: Linear::on_device(embed_dim, embed_dim, device)?,
            out_proj: Linear::on_device(embed_dim, embed_dim, device)?,
            num_heads,
            head_dim,
            scale: 1.0 / (head_dim as f64).sqrt(),
        })
    }

    /// Full attention forward with separate query, key, value and optional mask.
    ///
    /// Shapes:
    /// - query: `[batch, seq_q, embed_dim]`
    /// - key:   `[batch, seq_k, embed_dim]`
    /// - value: `[batch, seq_k, embed_dim]`
    /// - mask:  `[seq_q, seq_k]` or `[batch, 1, seq_q, seq_k]` (true/non-zero = masked positions)
    ///
    /// Returns: `[batch, seq_q, embed_dim]`
    pub fn forward_ext(
        &self,
        query: &Variable,
        key: &Variable,
        value: &Variable,
        mask: Option<&Tensor>,
    ) -> Result<Variable> {
        let batch = query.shape()[0];
        let seq_q = query.shape()[1];
        let seq_k = key.shape()[1];

        // Project Q, K, V
        let q = self.q_proj.forward(query)?;
        let k = self.k_proj.forward(key)?;
        let v = self.v_proj.forward(value)?;

        // Reshape to [batch, num_heads, seq, head_dim]
        let q = q.reshape(&[batch, seq_q, self.num_heads, self.head_dim])?
                 .transpose(1, 2)?;
        let k = k.reshape(&[batch, seq_k, self.num_heads, self.head_dim])?
                 .transpose(1, 2)?;
        let v = v.reshape(&[batch, seq_k, self.num_heads, self.head_dim])?
                 .transpose(1, 2)?;

        // Attention scores: [batch, heads, seq_q, seq_k]
        let k_t = k.transpose(2, 3)?;
        let mut scores = q.matmul(&k_t)?.mul_scalar(self.scale)?;

        // Apply mask (true/non-zero positions are filled with -inf)
        if let Some(m) = mask {
            scores = scores.masked_fill(m, f64::NEG_INFINITY)?;
        }

        // Softmax over key dimension
        let attn = scores.softmax(-1)?;

        // Weighted sum of values: [batch, heads, seq_q, head_dim]
        let out = attn.matmul(&v)?;

        // Reshape back: [batch, seq_q, embed_dim]
        let out = out.transpose(1, 2)?
                     .reshape(&[batch, seq_q, self.num_heads * self.head_dim])?;

        // Output projection
        self.out_proj.forward(&out)
    }
}

impl Module for MultiheadAttention {
    fn name(&self) -> &str { "multihead_attention" }

    /// Self-attention forward: query = key = value = input, no mask.
    fn forward(&self, input: &Variable) -> Result<Variable> {
        self.forward_ext(input, input, input, None)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.q_proj.parameters("q_proj"));
        params.extend(self.k_proj.parameters("k_proj"));
        params.extend(self.v_proj.parameters("v_proj"));
        params.extend(self.out_proj.parameters("out_proj"));
        params
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::test_device;

    #[test]
    fn test_mha_self_attention() {
        let device = test_device();
        let mha = MultiheadAttention::on_device(8, 2, device).unwrap();
        let opts = crate::tensor::test_opts();
        let x = Variable::new(
            Tensor::randn(&[2, 4, 8], opts).unwrap(), // [batch=2, seq=4, dim=8]
            false,
        );
        let y = mha.forward(&x).unwrap();
        assert_eq!(y.shape(), vec![2, 4, 8]);
    }

    #[test]
    fn test_mha_cross_attention() {
        let device = test_device();
        let mha = MultiheadAttention::on_device(8, 2, device).unwrap();
        let opts = crate::tensor::test_opts();
        let q = Variable::new(Tensor::randn(&[1, 3, 8], opts).unwrap(), false);
        let kv = Variable::new(Tensor::randn(&[1, 5, 8], opts).unwrap(), false);
        let y = mha.forward_ext(&q, &kv, &kv, None).unwrap();
        assert_eq!(y.shape(), vec![1, 3, 8]); // seq_q=3, not seq_k=5
    }

    #[test]
    fn test_mha_causal_mask() {
        let device = test_device();
        let mha = MultiheadAttention::on_device(8, 2, device).unwrap();
        let opts = crate::tensor::test_opts();
        let x = Variable::new(Tensor::randn(&[1, 4, 8], opts).unwrap(), false);

        // Causal mask: upper triangle = true (masked)
        let mask = Tensor::ones(&[4, 4], opts).unwrap().triu(1).unwrap();
        let y = mha.forward_ext(&x, &x, &x, Some(&mask)).unwrap();
        assert_eq!(y.shape(), vec![1, 4, 8]);
    }

    #[test]
    fn test_mha_gradient() {
        let device = test_device();
        let mha = MultiheadAttention::on_device(8, 2, device).unwrap();
        let opts = crate::tensor::test_opts();
        let x = Variable::new(Tensor::randn(&[1, 3, 8], opts).unwrap(), true);
        let y = mha.forward(&x).unwrap();
        let loss = y.sum().unwrap();
        loss.backward().unwrap();

        let grad = x.grad().unwrap();
        assert_eq!(grad.shape(), vec![1, 3, 8]);
    }

    #[test]
    fn test_mha_parameters() {
        let mha = MultiheadAttention::new(16, 4).unwrap();
        let params = mha.parameters();
        // 4 projections * 2 (weight + bias) = 8 parameters
        assert_eq!(params.len(), 8);
    }
}