// rai-models 0.11.0
//
// ML framework with Ergonomic APIs in Rust
// Documentation
use rai::{
    nn::{Activation, Embedding, Linear, Module, RmsNorm},
    AsDType, AsDevice, Module, Shape, Tensor, Type, F32,
};
use std::cell::RefCell;

/// Hyper-parameters of the model, deserialized with `serde`.
///
/// Several fields are not read by the code visible in this file; they are
/// noted as such below.
#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
pub struct Config {
    /// Row count of the token embedding table and output width of `lm_head`.
    pub vocab_size: usize,
    /// Width of the hidden states flowing through every layer.
    pub hidden_size: usize,
    /// Inner width of the gated MLP (`gate_proj` / `up_proj` output).
    pub intermediate_size: usize,
    /// Number of decoder layers that `Model::new` builds.
    pub num_hidden_layers: usize,
    /// Number of query heads; the per-head width is
    /// `hidden_size / num_attention_heads`.
    pub num_attention_heads: usize,
    /// Number of key/value heads used when reshaping keys and values.
    pub num_key_value_heads: usize,
    /// Number of positions for which rotary sin/cos rows are precomputed.
    pub max_position_embeddings: usize,
    /// Window size used when building the decoder attention mask: position
    /// `i` may not attend to position `j` when `j + sliding_window < i`.
    pub sliding_window: usize,
    /// Not read by the code visible in this file.
    pub max_window_layers: usize,
    /// Not read by the code visible in this file.
    pub tie_word_embeddings: bool,
    /// Base of the rotary-embedding inverse frequencies.
    pub rope_theta: f64,
    /// Epsilon passed to every `RmsNorm`.
    pub rms_norm_eps: f64,
    /// Not read by the code visible in this file.
    pub use_sliding_window: bool,
    /// Activation applied to the `gate_proj` output inside the MLP.
    pub hidden_act: Activation,
}

/// Precomputed rotary position embedding tables.
///
/// Both tensors are built in `RotaryEmbedding::new` with one row per position
/// in `0..max_position_embeddings`.
#[derive(Debug, Clone, Module)]
#[module(input = (Tensor, Tensor, usize), output = (Tensor, Tensor), trainable = false)]
struct RotaryEmbedding {
    /// Sine of the position/frequency angle table.
    sin: Tensor,
    /// Cosine of the position/frequency angle table.
    cos: Tensor,
}

/// Splits the last axis of `xs` at its midpoint and returns the negated
/// second part followed by the first part, joined again on the last axis.
fn rotate_half(xs: &Tensor) -> Tensor {
    let width = xs.shape_at(-1);
    let split = width / 2;
    let front = xs.narrow(-1, 0, split);
    // The second part takes whatever remains after the first `split` entries.
    let back = xs.narrow(-1, split, width - split);
    let negated_back = back.neg();
    Tensor::cat(&[&negated_back, &front], -1)
}

impl RotaryEmbedding {
    /// Precomputes the sine and cosine tables for every position in
    /// `0..cfg.max_position_embeddings`.
    ///
    /// The per-head width is `cfg.hidden_size / cfg.num_attention_heads`. One
    /// inverse frequency `1 / rope_theta^(i / width)` is produced for every
    /// even `i` below that width, and the position-by-frequency angle table is
    /// duplicated along the last axis before `sin`/`cos` are taken.
    pub fn new(cfg: &Config, dtype: impl Type, device: impl AsDevice) -> Self {
        let device = device.device();
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
        let positions = cfg.max_position_embeddings;

        let mut inv_freq = Vec::new();
        for i in (0..head_dim).step_by(2) {
            let exponent = i as f64 / head_dim as f64;
            // The power is evaluated in f64 and narrowed to f32 before the
            // reciprocal is taken.
            let denom = cfg.rope_theta.powf(exponent) as f32;
            inv_freq.push(1f32 / denom);
        }
        let n_freq = inv_freq.len();
        let inv_freq = Tensor::from_array(inv_freq, [1, n_freq], device).to_dtype(dtype);

        // Column vector of positions times row vector of inverse frequencies.
        let t = Tensor::arange((0u32, positions as u32), device)
            .to_dtype(dtype)
            .reshape([positions, 1]);
        let angles = t.matmul(&inv_freq);
        let angles = Tensor::cat(&[&angles, &angles], -1);

        let sin = angles.sin();
        let cos = angles.cos();
        Self { sin, cos }
    }

    /// Applies the rotary embedding to `q` and `k`.
    ///
    /// The sequence length is read from axis 2 of `q`; the table rows
    /// `seqlen_offset..seqlen_offset + seq_len` are used. Returns the rotated
    /// `(q, k)` pair.
    pub fn fwd(&self, q: &Tensor, k: &Tensor, seqlen_offset: usize) -> (Tensor, Tensor) {
        let [_, _, seq_len, _]: [usize; 4] = q.shape_before::<4>();

        // Slice the rows for the current positions and add two leading axes so
        // the tables become (1, 1, seq_len, dim).
        let cos = self
            .cos
            .narrow(0, seqlen_offset, seq_len)
            .unsqueeze(0)
            .unsqueeze(0);
        let sin = self
            .sin
            .narrow(0, seqlen_offset, seq_len)
            .unsqueeze(0)
            .unsqueeze(0);

        let q_scaled = q * &cos;
        let q_rotated = rotate_half(q) * &sin;
        let q_embed = q_scaled + q_rotated;

        let k_scaled = k * &cos;
        let k_rotated = rotate_half(k) * &sin;
        let k_embed = k_scaled + k_rotated;

        (q_embed, k_embed)
    }
}

/// Gated feed-forward block: `down_proj(act_fn(gate_proj(x)) * up_proj(x))`.
#[derive(Debug, Clone, Module)]
// The all-caps name is kept deliberately, so the clippy acronym lint is silenced.
#[allow(clippy::upper_case_acronyms)]
struct MLP {
    /// Projection `hidden_size -> intermediate_size` whose output is activated.
    gate_proj: Linear,
    /// Projection `hidden_size -> intermediate_size` multiplied with the gate.
    up_proj: Linear,
    /// Projection `intermediate_size -> hidden_size` applied to the product.
    down_proj: Linear,
    /// Activation taken from `Config::hidden_act`.
    act_fn: Activation,
}

impl MLP {
    /// Creates the three bias-free projections and copies the activation from
    /// `cfg.hidden_act`.
    pub fn new(cfg: &Config, dtype: impl Type, device: impl AsDevice) -> Self {
        let device = device.device();
        let (hidden, inner) = (cfg.hidden_size, cfg.intermediate_size);
        // Struct-literal fields are evaluated in the order written, so the
        // layers are still created gate -> up -> down.
        Self {
            gate_proj: Linear::new(hidden, inner, false, dtype, device),
            up_proj: Linear::new(hidden, inner, false, dtype, device),
            down_proj: Linear::new(inner, hidden, false, dtype, device),
            act_fn: cfg.hidden_act,
        }
    }

    /// Computes `down_proj(act_fn(gate_proj(xs)) * up_proj(xs))`.
    pub fn fwd(&self, xs: &Tensor) -> Tensor {
        let gate = xs.apply(&self.gate_proj).apply(&self.act_fn);
        let up = xs.apply(&self.up_proj);
        let gated = gate * up;
        gated.apply(&self.down_proj)
    }
}

/// Self-attention block with rotary position embedding and a key/value cache.
///
/// NOTE(review): fields marked `#[param(skip)]` are presumably excluded from
/// the module's parameter set by the `Module` derive; confirm against the
/// `rai` derive documentation.
#[derive(Debug, Clone, Module)]
#[module(input = (Tensor, Option<Tensor>, usize))]
struct Attention {
    /// Query projection (with bias).
    q_proj: Linear,
    /// Key projection (with bias).
    k_proj: Linear,
    /// Value projection (with bias).
    v_proj: Linear,
    /// Output projection (no bias).
    o_proj: Linear,
    /// Number of query heads.
    #[param(skip)]
    num_heads: usize,
    /// Number of key/value heads used when reshaping keys and values.
    #[param(skip)]
    num_kv_heads: usize,
    /// How many times `repeat_kv` repeats each key/value head.
    #[param(skip)]
    num_kv_groups: usize,
    /// Width of a single head, `hidden_size / num_heads`.
    #[param(skip)]
    head_dim: usize,
    /// Width of the block's input and output.
    #[param(skip)]
    hidden_size: usize,
    /// Precomputed rotary tables applied to queries and keys.
    #[param(skip)]
    rotary_emb: RotaryEmbedding,
    /// Keys and values accumulated over previous `fwd` calls; `None` until
    /// the first call or after `clear_kv_cache`. Interior mutability lets
    /// `fwd(&self, ..)` update it.
    #[param(skip)]
    kv_cache: RefCell<Option<(Tensor, Tensor)>>,
}

impl Attention {
    /// Builds the attention block described by `cfg`.
    ///
    /// The key and value projections produce `num_key_value_heads * head_dim`
    /// features, where `head_dim = hidden_size / num_attention_heads`, and
    /// every key/value head is shared by
    /// `num_attention_heads / num_key_value_heads` query heads (grouped-query
    /// attention). When both head counts are equal this is ordinary
    /// multi-head attention with `hidden_size`-wide projections.
    ///
    /// # Panics
    ///
    /// Panics if `num_key_value_heads` is zero or does not divide
    /// `num_attention_heads`.
    pub fn new(cfg: &Config, dtype: impl Type, device: impl AsDevice) -> Self {
        let device = device.device();
        let hidden_size = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        assert!(
            num_kv_heads > 0 && num_heads % num_kv_heads == 0,
            "num_attention_heads ({}) must be a multiple of a non-zero num_key_value_heads ({})",
            num_heads,
            num_kv_heads
        );
        let head_dim = hidden_size / num_heads;
        // Each key/value head serves this many query heads.
        let num_kv_groups = num_heads / num_kv_heads;
        // Keys and values only carry `num_kv_heads` heads, so their projections
        // are narrower than the query projection whenever the counts differ.
        let kv_size = num_kv_heads * head_dim;
        let q_proj = Linear::new(hidden_size, hidden_size, true, dtype, device);
        let k_proj = Linear::new(hidden_size, kv_size, true, dtype, device);
        let v_proj = Linear::new(hidden_size, kv_size, true, dtype, device);
        let o_proj = Linear::new(hidden_size, hidden_size, false, dtype, device);
        let rotary_emb = RotaryEmbedding::new(cfg, dtype, device);
        let kv_cache = RefCell::new(None);
        Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads,
            num_kv_heads,
            num_kv_groups,
            head_dim,
            hidden_size,
            rotary_emb,
            kv_cache,
        }
    }

    /// Repeats every key/value head `num_kv_groups` times along the head axis
    /// so the result has as many heads as the queries.
    ///
    /// `xs` is read as `(batch, num_kv_heads, seq_len, head_dim)`; it is
    /// returned unchanged when no repetition is needed.
    fn repeat_kv(&self, xs: Tensor) -> Tensor {
        let n_rep = self.num_kv_groups;
        if n_rep == 1 {
            xs
        } else {
            let [b_sz, num_kv_heads, seq_len, head_dim] = xs.shape_before::<4>();
            // Insert a repeat axis after the head axis, broadcast over it, and
            // fold it back into the head axis.
            xs.unsqueeze(2)
                .broadcast_to([b_sz, num_kv_heads, n_rep, seq_len, head_dim])
                .reshape([b_sz, num_kv_heads * n_rep, seq_len, head_dim])
        }
    }

    /// Runs scaled dot-product self-attention over `xs`.
    ///
    /// * `xs` - input whose first two axes are read as `(batch, q_len)`.
    /// * `attention_mask` - optional tensor added to the scaled scores before
    ///   the softmax.
    /// * `seqlen_offset` - position of the first token of `xs`, forwarded to
    ///   the rotary embedding.
    ///
    /// The keys and values of this call are appended to the cached ones along
    /// the sequence axis and the result is stored back into the cache, so
    /// later calls attend over everything seen since the last
    /// `clear_kv_cache`.
    pub fn fwd(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offset: usize,
    ) -> Tensor {
        let [b_sz, q_len] = xs.shape_before::<2>();
        let query_states = self.q_proj.forward(xs);
        let key_states = self.k_proj.forward(xs);
        let value_states = self.v_proj.forward(xs);
        // Split the feature axis into heads and move the head axis before the
        // sequence axis: (batch, heads, q_len, head_dim).
        let query_states = query_states
            .reshape([b_sz, q_len, self.num_heads, self.head_dim])
            .transpose(1, 2);
        let key_states = key_states
            .reshape([b_sz, q_len, self.num_kv_heads, self.head_dim])
            .transpose(1, 2);
        let value_states = value_states
            .reshape([b_sz, q_len, self.num_kv_heads, self.head_dim])
            .transpose(1, 2);
        let (query_states, key_states) =
            self.rotary_emb
                .fwd(&query_states, &key_states, seqlen_offset);

        // Extend the cached keys/values along the sequence axis (axis 2).
        let kv_cache = self.kv_cache.borrow();
        let (key_states, value_states) = match &*kv_cache {
            None => (key_states, value_states),
            Some((prev_k, prev_v)) => {
                let key_states = Tensor::cat(&[prev_k, &key_states], 2);
                let value_states = Tensor::cat(&[prev_v, &value_states], 2);
                (key_states, value_states)
            }
        };
        // Release the shared borrow before the cell is written to.
        drop(kv_cache);
        self.kv_cache
            .replace(Some((key_states.clone(), value_states.clone())));
        // The cache keeps the un-repeated heads; expansion happens per call.
        let key_states = self.repeat_kv(key_states);
        let value_states = self.repeat_kv(value_states);
        let attn_output = {
            let scale = 1f64 / f64::sqrt(self.head_dim as f64);
            let attn_weights = query_states.matmul(&key_states.transpose(2, 3)) * scale;
            let attn_weights = match attention_mask {
                None => attn_weights,
                Some(mask) => attn_weights + mask,
            };
            let attn_weights = attn_weights.softmax(-1);
            attn_weights.matmul(&value_states)
        };
        // Merge the heads back into the feature axis and project out.
        attn_output
            .transpose(1, 2)
            .reshape([b_sz, q_len, self.hidden_size])
            .apply(&self.o_proj)
    }

    /// Drops the cached keys and values so the next `fwd` call starts fresh.
    pub fn clear_kv_cache(&self) {
        self.kv_cache.replace(None);
    }
}

/// One decoder layer: normalized self-attention and a normalized MLP, each
/// wrapped in a residual connection.
#[derive(Debug, Clone, Module)]
#[module(input = (Tensor, Option<Tensor>, usize))]
struct DecoderLayer {
    /// Self-attention sub-block (owns the key/value cache).
    self_attn: Attention,
    /// Gated feed-forward sub-block.
    mlp: MLP,
    /// Normalization applied to the layer input before attention.
    input_layernorm: RmsNorm,
    /// Normalization applied to the attention residual sum before the MLP.
    post_attention_layernorm: RmsNorm,
}

impl DecoderLayer {
    /// Creates the attention block, the MLP and the two `RmsNorm` layers from
    /// `cfg`.
    pub fn new(cfg: &Config, dtype: impl Type, device: impl AsDevice) -> Self {
        let device = device.device();
        let (width, eps) = (cfg.hidden_size, cfg.rms_norm_eps);
        // Fields are evaluated top to bottom, matching the creation order
        // attention -> MLP -> input norm -> post-attention norm.
        Self {
            self_attn: Attention::new(cfg, dtype, device),
            mlp: MLP::new(cfg, dtype, device),
            input_layernorm: RmsNorm::new(width, eps, dtype, device),
            post_attention_layernorm: RmsNorm::new(width, eps, dtype, device),
        }
    }

    /// Runs the layer on `xs`.
    ///
    /// First `attn(norm(xs)) + xs` is computed, then the MLP of the normalized
    /// result is added back onto that sum. `attention_mask` and
    /// `seqlen_offset` are passed through to the attention block.
    pub fn fwd(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offset: usize,
    ) -> Tensor {
        // Attention sub-block with residual connection.
        let normed = self.input_layernorm.forward(xs);
        let attended = self.self_attn.fwd(&normed, attention_mask, seqlen_offset);
        let hidden = attended + xs;

        // Feed-forward sub-block with residual connection.
        let fed_forward = hidden
            .apply(&self.post_attention_layernorm)
            .apply(&self.mlp);
        &hidden + fed_forward
    }

    /// Clears the attention block's key/value cache.
    pub fn clear_kv_cache(&self) {
        self.self_attn.clear_kv_cache()
    }
}

/// Decoder-only language model: token embedding, a stack of decoder layers,
/// a final normalization and an output projection onto the vocabulary.
///
/// NOTE(review): the `#[param(rename = "...")]` attributes presumably set the
/// names under which the parameters are looked up; confirm against the `rai`
/// derive documentation.
#[derive(Debug, Clone, Module)]
#[module(input = (Tensor,usize))]
pub struct Model {
    /// Token embedding table (`vocab_size` x `hidden_size`).
    #[param(rename = "model.embed_tokens")]
    embed_tokens: Embedding,
    /// Decoder layers, applied in order.
    #[param(rename = "model.layers")]
    layers: Vec<DecoderLayer>,
    /// Normalization applied before the output projection.
    #[param(rename = "model.norm")]
    norm: RmsNorm,
    /// Bias-free projection `hidden_size -> vocab_size`.
    lm_head: Linear,
    /// Window size used when building the attention mask.
    #[param(skip)]
    sliding_window: usize,
}

impl Model {
    /// Builds the embedding table, `cfg.num_hidden_layers` decoder layers, the
    /// final norm and the output projection.
    pub fn new(cfg: &Config, dtype: impl Type, device: impl AsDevice) -> Self {
        let device = device.device();
        let embed_tokens = Embedding::new(cfg.vocab_size, cfg.hidden_size, dtype, device);

        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        for _ in 0..cfg.num_hidden_layers {
            layers.push(DecoderLayer::new(cfg, dtype, device));
        }

        let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, dtype, device);
        let lm_head = Linear::new(cfg.hidden_size, cfg.vocab_size, false, dtype, device);
        Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            sliding_window: cfg.sliding_window,
        }
    }

    /// Builds the additive attention mask for `tgt_len` new positions.
    ///
    /// Entry `(i, j)` of the square part is `0` when `j <= i` and
    /// `i <= j + sliding_window`, and negative infinity otherwise. When
    /// `seqlen_offset > 0`, that many zero columns are prepended. The result
    /// is broadcast to `(b_size, 1, tgt_len, tgt_len + seqlen_offset)` and
    /// converted to `dtype`.
    fn prepare_decoder_attention_mask(
        &self,
        b_size: usize,
        tgt_len: usize,
        seqlen_offset: usize,
        dtype: impl AsDType,
        device: impl AsDevice,
    ) -> Tensor {
        let device = device.device();
        let window = self.sliding_window;

        // Row-major fill: `i` indexes the querying position, `j` the attended one.
        let mut values = Vec::with_capacity(tgt_len * tgt_len);
        for i in 0..tgt_len {
            for j in 0..tgt_len {
                let visible = j <= i && i <= j + window;
                values.push(if visible { 0f32 } else { f32::NEG_INFINITY });
            }
        }
        let square = Tensor::from_array(values, [tgt_len, tgt_len], device);

        let mask = if seqlen_offset == 0 {
            square
        } else {
            // Zero columns are placed in front for the offset positions.
            let past = Tensor::zeros([tgt_len, seqlen_offset], F32, device);
            Tensor::cat(&[&past, &square], -1)
        };

        let full_shape = [b_size, 1, tgt_len, tgt_len + seqlen_offset];
        mask.broadcast_to(full_shape).to_dtype(dtype)
    }

    /// Runs the model on `input`, whose first two axes are read as
    /// `(batch, seq_len)`, and returns the output projection for the last
    /// sequence position only.
    ///
    /// No mask is built when `seq_len <= 1`. `seqlen_offset` is forwarded to
    /// the mask construction and to every layer.
    pub fn fwd(&self, input: &Tensor, seqlen_offset: usize) -> Tensor {
        let [b_size, seq_len] = input.shape_before::<2>();

        let attention_mask = (seq_len > 1).then(|| {
            let weight = self.lm_head.weight();
            self.prepare_decoder_attention_mask(
                b_size,
                seq_len,
                seqlen_offset,
                weight.dtype(),
                self.lm_head.weight().device(),
            )
        });

        let embedded = self.embed_tokens.forward(input);
        let hidden = self.layers.iter().fold(embedded, |xs, layer| {
            layer.fwd(&xs, attention_mask.as_ref(), seqlen_offset)
        });

        // Keep only the final position before the norm and the projection.
        hidden
            .narrow(1, seq_len - 1, 1)
            .apply(&self.norm)
            .apply(&self.lm_head)
    }

    /// Clears the key/value cache of every decoder layer.
    pub fn clear_kv_cache(&self) {
        self.layers
            .iter()
            .for_each(|layer| layer.clear_kv_cache());
    }
}