candle_transformers/models/persimmon.rs

//! Persimmon Model
//!
//! A transformer language model for efficient inference and general-purpose
//! language tasks. The model uses a standard transformer architecture with:
//! - Optional layer normalization on the query and key projections
//!   (`qk_layernorm`)
//! - Rotary position embeddings (RoPE) applied to only part of each head's
//!   dimensions (`partial_rotary_factor`)
//! - ReLU activation
//! - Separately configurable numbers of attention heads and key/value heads
//!
//! References:
//! - 💻 [Hugging Face Implementation](https://github.com/huggingface/transformers/blob/main/src/transformers/models/persimmon/modeling_persimmon.py)
//! - 💻 [Persimmon Config](https://github.com/huggingface/transformers/blob/main/src/transformers/models/persimmon/configuration_persimmon.py)
//! - 🤗 [Hugging Face](https://huggingface.co/adept/persimmon-8b-base)

use candle::DType;
use serde::Deserialize;

/// Dtype used for the Persimmon weights.
pub const DTYPE: DType = DType::F32;

/// Position embedding variants recognized in the Hugging Face configuration.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum PositionEmbeddingType {
    Absolute,
    Alibi,
}
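
// With `rename_all = "lowercase"`, serde deserializes the JSON strings
// "absolute" and "alibi" into the matching variants above.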

// https://github.com/huggingface/transformers/blob/main/src/transformers/models/persimmon/configuration_persimmon.py
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
    pub vocab_size: usize,
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    pub num_key_value_heads: usize,
    pub hidden_act: candle_nn::Activation,
    pub max_position_embeddings: usize,
    pub initializer_range: f64,
    pub layer_norm_eps: f64,
    pub rms_norm_eps: f64,
    pub use_cache: bool,
    pub tie_word_embeddings: bool,
    pub rope_theta: f64,
    pub qk_layernorm: bool,
    pub partial_rotary_factor: f64,
}
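
// How `partial_rotary_factor` is consumed (an illustrative sketch mirroring
// the Hugging Face implementation linked above; `head_dim` and
// `rotary_ndims` are hypothetical helpers, not part of the upstream API):
impl Config {
    /// Dimension of a single attention head.
    pub fn head_dim(&self) -> usize {
        self.hidden_size / self.num_attention_heads
    }

    /// Number of head dimensions rotated by RoPE; the remaining
    /// `head_dim() - rotary_ndims()` dimensions bypass the rotation.
    /// For `base_8b()`: head_dim = 4096 / 64 = 64, so 64 * 0.5 = 32
    /// dimensions are rotated.
    pub fn rotary_ndims(&self) -> usize {
        (self.head_dim() as f64 * self.partial_rotary_factor) as usize
    }
}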

impl Config {
    /// Hyperparameters of the 8B base checkpoint.
    pub fn base_8b() -> Self {
        // https://huggingface.co/adept/persimmon-8b-base/blob/main/config.json
        Self {
            hidden_act: candle_nn::Activation::Relu,
            hidden_size: 4096,
            initializer_range: 0.02,
            intermediate_size: 16384,
            layer_norm_eps: 1e-05,
            max_position_embeddings: 16384,
            num_attention_heads: 64,
            num_hidden_layers: 36,
            num_key_value_heads: 64,
            qk_layernorm: true,
            rms_norm_eps: 1e-06,
            rope_theta: 25000.0,
            tie_word_embeddings: false,
            use_cache: true,
            vocab_size: 262144,
            partial_rotary_factor: 0.5,
        }
    }
}
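
// Usage sketch (hedged): since `Config` derives `Deserialize`, a Hugging
// Face `config.json` can be loaded directly, assuming the `serde_json`
// crate is available as a dependency:
//
//     let json = std::fs::read_to_string("config.json")?;
//     let cfg: Config = serde_json::from_str(&json)?;
//     assert!(cfg.qk_layernorm);

#[cfg(test)]
mod tests {
    use super::*;

    // Sanity check over the base_8b() hyperparameters defined above.
    #[test]
    fn base_8b_head_geometry() {
        let cfg = Config::base_8b();
        let head_dim = cfg.hidden_size / cfg.num_attention_heads; // 4096 / 64
        assert_eq!(head_dim, 64);
        // partial_rotary_factor = 0.5 => half of each head dim gets RoPE.
        assert_eq!((head_dim as f64 * cfg.partial_rotary_factor) as usize, 32);
    }
}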