use super::TransformerConfig;
/// Builds the "tiny" preset: hidden size 128, feed-forward size 256, 2 layers, 4 heads.
///
/// `vocab_size` and `context_len` are copied into the returned config unchanged.
///
/// NOTE(review): unlike `small_transformer` and `base_transformer` in this file, this
/// preset is returned without calling `validate()` or any other check — confirm that
/// this is intended.
pub(crate) fn tiny_transformer(vocab_size: usize, context_len: usize) -> TransformerConfig {
    // Fixed architecture dimensions for this preset; the literal types are inferred
    // from the corresponding `TransformerConfig` fields.
    let hidden_size = 128;
    let ffw_size = 256;
    let num_layers = 2;
    let num_heads = 4;

    TransformerConfig {
        vocab_size,
        context_len,
        hidden_size,
        ffw_size,
        num_layers,
        num_heads,
    }
}
/// Builds the "small" preset: hidden size 256, feed-forward size 1024, 6 layers, 8 heads.
///
/// `vocab_size` and `context_len` are copied into the config unchanged. Before the
/// config is returned, `validate()` and each derived-quantity helper is called once;
/// the returned values are discarded and only a failure has any effect.
///
/// # Panics
///
/// Panics with a `small_transformer …` message at the first check whose `expect`
/// fails. The checks run in the order written below.
pub(crate) fn small_transformer(vocab_size: usize, context_len: usize) -> TransformerConfig {
    let preset = TransformerConfig {
        vocab_size,
        context_len,
        hidden_size: 256,
        ffw_size: 1024,
        num_layers: 6,
        num_heads: 8,
    };

    // Basic validation runs before any derived quantity is computed.
    let _ = preset.validate().expect("small_transformer invalid");

    // Size-related helpers that take no runtime shape.
    let _ = preset
        .attention_head_dim()
        .expect("small_transformer attention_head_dim failed");
    let _ = preset
        .approximate_parameter_count()
        .expect("small_transformer approximate_parameter_count failed");
    let _ = preset
        .parameters_per_layer()
        .expect("small_transformer parameters_per_layer failed");
    let _ = preset
        .embedding_parameter_count()
        .expect("small_transformer embedding_parameter_count failed");

    // Shape-dependent helpers, probed with `(1, 1)`.
    // NOTE(review): presumably a minimal batch/sequence shape — confirm against the
    // helper signatures.
    let _ = preset
        .total_token_elements(1, 1)
        .expect("small_transformer total_token_elements failed");
    let _ = preset
        .activation_elements(1, 1)
        .expect("small_transformer activation_elements failed");
    let _ = preset
        .kv_cache_elements(1, 1)
        .expect("small_transformer kv_cache_elements failed");
    let _ = preset
        .validate_runtime_shape(1, 1)
        .expect("small_transformer validate_runtime_shape failed");

    preset
}
/// Builds the "base" preset: hidden size 768, feed-forward size 3072, 12 layers, 12 heads.
///
/// `vocab_size` and `context_len` are copied into the config unchanged. Before the
/// config is returned, `validate()` and each derived-quantity helper is called once;
/// the returned values are discarded and only a failure has any effect.
///
/// # Panics
///
/// Panics with a `base_transformer …` message at the first check whose `expect`
/// fails. The checks run in the order written below.
pub(crate) fn base_transformer(vocab_size: usize, context_len: usize) -> TransformerConfig {
    let preset = TransformerConfig {
        vocab_size,
        context_len,
        hidden_size: 768,
        ffw_size: 3072,
        num_layers: 12,
        num_heads: 12,
    };

    // Basic validation runs before any derived quantity is computed.
    let _ = preset.validate().expect("base_transformer invalid");

    // Size-related helpers that take no runtime shape.
    let _ = preset
        .attention_head_dim()
        .expect("base_transformer attention_head_dim failed");
    let _ = preset
        .approximate_parameter_count()
        .expect("base_transformer approximate_parameter_count failed");
    let _ = preset
        .parameters_per_layer()
        .expect("base_transformer parameters_per_layer failed");
    let _ = preset
        .embedding_parameter_count()
        .expect("base_transformer embedding_parameter_count failed");

    // Shape-dependent helpers, probed with `(1, 1)`.
    // NOTE(review): presumably a minimal batch/sequence shape — confirm against the
    // helper signatures.
    let _ = preset
        .total_token_elements(1, 1)
        .expect("base_transformer total_token_elements failed");
    let _ = preset
        .activation_elements(1, 1)
        .expect("base_transformer activation_elements failed");
    let _ = preset
        .kv_cache_elements(1, 1)
        .expect("base_transformer kv_cache_elements failed");
    let _ = preset
        .validate_runtime_shape(1, 1)
        .expect("base_transformer validate_runtime_shape failed");

    preset
}