native_neural_network 0.3.1

A no_std Rust library for native neural networks (.rnn)
Documentation
use super::TransformerConfig;

/// Builds the "tiny" transformer preset.
///
/// `vocab_size` and `context_len` are copied into the config unchanged; the
/// remaining hyperparameters are fixed: hidden size 128, feed-forward size
/// 256, 2 layers and 4 attention heads.
///
/// This constructor performs no validation of the resulting config.
pub(crate) fn tiny_transformer(vocab_size: usize, context_len: usize) -> TransformerConfig {
    // Fixed hyperparameters of the preset; only the two arguments vary.
    let (hidden_size, ffw_size, num_layers, num_heads) = (128, 256, 2, 4);

    TransformerConfig {
        vocab_size,
        context_len,
        hidden_size,
        ffw_size,
        num_layers,
        num_heads,
    }
}

/// Builds the "small" transformer preset and runs the config's checks on it.
///
/// `vocab_size` and `context_len` are copied into the config unchanged; the
/// remaining hyperparameters are fixed: hidden size 256, feed-forward size
/// 1024, 6 layers and 8 attention heads.
///
/// After construction the config's `validate`, `attention_head_dim`,
/// `approximate_parameter_count`, `parameters_per_layer`,
/// `embedding_parameter_count`, `total_token_elements(1, 1)`,
/// `activation_elements(1, 1)`, `kv_cache_elements(1, 1)` and
/// `validate_runtime_shape(1, 1)` methods are called in that order.
///
/// # Panics
///
/// Panics with a `small_transformer ...` message as soon as one of the calls
/// listed above fails.
pub(crate) fn small_transformer(vocab_size: usize, context_len: usize) -> TransformerConfig {
    // Fixed hyperparameters of the preset; only the two arguments vary.
    let (hidden_size, ffw_size, num_layers, num_heads) = (256, 1024, 6, 8);
    let preset = TransformerConfig {
        vocab_size,
        context_len,
        hidden_size,
        ffw_size,
        num_layers,
        num_heads,
    };

    // NOTE(review): every result below is discarded, so these calls matter
    // only for their panic-on-failure effect. The order is kept fixed because
    // it decides which message is reported first when several would fail.
    preset.validate().expect("small_transformer invalid");

    // Derived-size queries that take no arguments.
    preset
        .attention_head_dim()
        .expect("small_transformer attention_head_dim failed");
    preset
        .approximate_parameter_count()
        .expect("small_transformer approximate_parameter_count failed");
    preset
        .parameters_per_layer()
        .expect("small_transformer parameters_per_layer failed");
    preset
        .embedding_parameter_count()
        .expect("small_transformer embedding_parameter_count failed");

    // Shape-dependent queries, each probed with the arguments (1, 1).
    preset
        .total_token_elements(1, 1)
        .expect("small_transformer total_token_elements failed");
    preset
        .activation_elements(1, 1)
        .expect("small_transformer activation_elements failed");
    preset
        .kv_cache_elements(1, 1)
        .expect("small_transformer kv_cache_elements failed");
    preset
        .validate_runtime_shape(1, 1)
        .expect("small_transformer validate_runtime_shape failed");

    preset
}

/// Builds the "base" transformer preset and runs the config's checks on it.
///
/// `vocab_size` and `context_len` are copied into the config unchanged; the
/// remaining hyperparameters are fixed: hidden size 768, feed-forward size
/// 3072, 12 layers and 12 attention heads.
///
/// After construction the config's `validate`, `attention_head_dim`,
/// `approximate_parameter_count`, `parameters_per_layer`,
/// `embedding_parameter_count`, `total_token_elements(1, 1)`,
/// `activation_elements(1, 1)`, `kv_cache_elements(1, 1)` and
/// `validate_runtime_shape(1, 1)` methods are called in that order.
///
/// # Panics
///
/// Panics with a `base_transformer ...` message as soon as one of the calls
/// listed above fails.
pub(crate) fn base_transformer(vocab_size: usize, context_len: usize) -> TransformerConfig {
    // Fixed hyperparameters of the preset; only the two arguments vary.
    let (hidden_size, ffw_size, num_layers, num_heads) = (768, 3072, 12, 12);
    let preset = TransformerConfig {
        vocab_size,
        context_len,
        hidden_size,
        ffw_size,
        num_layers,
        num_heads,
    };

    // NOTE(review): every result below is discarded, so these calls matter
    // only for their panic-on-failure effect. The order is kept fixed because
    // it decides which message is reported first when several would fail.
    preset.validate().expect("base_transformer invalid");

    // Derived-size queries that take no arguments.
    preset
        .attention_head_dim()
        .expect("base_transformer attention_head_dim failed");
    preset
        .approximate_parameter_count()
        .expect("base_transformer approximate_parameter_count failed");
    preset
        .parameters_per_layer()
        .expect("base_transformer parameters_per_layer failed");
    preset
        .embedding_parameter_count()
        .expect("base_transformer embedding_parameter_count failed");

    // Shape-dependent queries, each probed with the arguments (1, 1).
    preset
        .total_token_elements(1, 1)
        .expect("base_transformer total_token_elements failed");
    preset
        .activation_elements(1, 1)
        .expect("base_transformer activation_elements failed");
    preset
        .kv_cache_elements(1, 1)
        .expect("base_transformer kv_cache_elements failed");
    preset
        .validate_runtime_shape(1, 1)
        .expect("base_transformer validate_runtime_shape failed");

    preset
}