// candle_transformers/models/persimmon.rs

use candle::DType;
use serde::Deserialize;

pub const DTYPE: DType = DType::F32;

#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum PositionEmbeddingType {
    Absolute,
    Alibi,
}
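
// A minimal sketch, not part of the upstream file: with
// `#[serde(rename_all = "lowercase")]`, the JSON strings "absolute" and
// "alibi" deserialize into the matching variants. Assumes `serde_json`
// is available as a dev-dependency.
#[cfg(test)]
mod position_embedding_type_sketch {
    use super::PositionEmbeddingType;

    #[test]
    fn deserializes_lowercase_names() {
        let abs: PositionEmbeddingType = serde_json::from_str("\"absolute\"").unwrap();
        assert_eq!(abs, PositionEmbeddingType::Absolute);
        let alibi: PositionEmbeddingType = serde_json::from_str("\"alibi\"").unwrap();
        assert_eq!(alibi, PositionEmbeddingType::Alibi);
    }
}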

#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
    pub vocab_size: usize,
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    pub num_key_value_heads: usize,
    pub hidden_act: candle_nn::Activation,
    pub max_position_embeddings: usize,
    pub initializer_range: f64,
    pub layer_norm_eps: f64,
    pub rms_norm_eps: f64,
    pub use_cache: bool,
    pub tie_word_embeddings: bool,
    pub rope_theta: f64,
    pub qk_layernorm: bool,
    pub partial_rotary_factor: f64,
}
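
// A minimal sketch, not part of the upstream file: the field names above
// line up with the keys of a Hugging Face style `config.json`, so the
// struct deserializes straight from that file. Unknown keys such as
// `model_type` are ignored by serde's default behavior. Assumes
// `serde_json` is available as a dev-dependency; the values mirror
// `Config::base_8b` below.
#[cfg(test)]
mod config_deserialize_sketch {
    use super::Config;

    #[test]
    fn parses_hf_style_config_json() {
        let json = serde_json::json!({
            "vocab_size": 262144,
            "hidden_size": 4096,
            "intermediate_size": 16384,
            "num_hidden_layers": 36,
            "num_attention_heads": 64,
            "num_key_value_heads": 64,
            "hidden_act": "relu",
            "max_position_embeddings": 16384,
            "initializer_range": 0.02,
            "layer_norm_eps": 1e-05,
            "rms_norm_eps": 1e-06,
            "use_cache": true,
            "tie_word_embeddings": false,
            "rope_theta": 25000.0,
            "qk_layernorm": true,
            "partial_rotary_factor": 0.5,
            "model_type": "persimmon"
        });
        let cfg: Config = serde_json::from_value(json).unwrap();
        assert_eq!(cfg, Config::base_8b());
    }
}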

impl Config {
    pub fn base_8b() -> Self {
        Self {
            hidden_act: candle_nn::Activation::Relu,
            hidden_size: 4096,
            initializer_range: 0.02,
            intermediate_size: 16384,
            layer_norm_eps: 1e-05,
            max_position_embeddings: 16384,
            num_attention_heads: 64,
            num_hidden_layers: 36,
            num_key_value_heads: 64,
            qk_layernorm: true,
            rms_norm_eps: 1e-06,
            rope_theta: 25000.0,
            tie_word_embeddings: false,
            use_cache: true,
            vocab_size: 262144,
            partial_rotary_factor: 0.5,
        }
    }
}
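
// A hedged worked example, not upstream code: the attention shapes that
// follow from the `base_8b` values. With `hidden_size` 4096 split across
// 64 heads, each head is 4096 / 64 = 64 wide, and a
// `partial_rotary_factor` of 0.5 means rotary position embeddings apply
// to the first 64 * 0.5 = 32 dimensions of each head, leaving the
// remaining 32 unrotated.
#[cfg(test)]
mod base_8b_derived_dims_sketch {
    use super::Config;

    #[test]
    fn head_and_rotary_dims() {
        let cfg = Config::base_8b();
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
        assert_eq!(head_dim, 64);
        let rotary_dim = (head_dim as f64 * cfg.partial_rotary_factor) as usize;
        assert_eq!(rotary_dim, 32);
    }
}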