use crate::error::{Error, Result};
use crate::model::config::UniversalConfig;
use crate::model::mamba::mamba2::Mamba2Config;
use crate::model::traits::Model;
use crate::nn::VarBuilder;
use numr::dtype::DType;
use numr::ops::IndexingOps;
use numr::runtime::Runtime;
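/// A loaded model, dispatched at runtime over every architecture the engine
/// can serve. Each variant is boxed so the enum itself stays small.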
pub enum LoadedModel<R: Runtime> {
Llama(Box<super::llama::Llama<R>>),
LlamaTp(Box<super::llama::LlamaTp<R>>),
Mamba2(Box<super::mamba::Mamba2Model<R>>),
Hybrid(Box<super::hybrid::HybridModel<R>>),
Multimodal(Box<super::multimodal::MultimodalModel<R>>),
}
impl<R: Runtime<DType = DType>> LoadedModel<R>
where
R::Client: IndexingOps<R> + crate::quant::DequantOps<R> + numr::ops::TypeConversionOps<R>,
{
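    /// Loads a model from a [`UniversalConfig`], dispatching on `model_type`
    /// and on which optional sub-configs are present:
    ///
    /// - `"mamba2"` / `"mamba3"`: pure SSM `Mamba2Model`
    /// - `"hybrid"`: `HybridModel`
    /// - any other type with a vision or audio config: `MultimodalModel`
    /// - any other type with an attention config: `Llama`
    ///
    /// Everything else is rejected with [`Error::ModelError`].
    ///
    /// A minimal usage sketch (how `config` and `vb` are constructed is
    /// elided here; see the checkpoint-loading code for that):
    ///
    /// ```ignore
    /// let model = LoadedModel::load(&config, &mut vb)?;
    /// println!("{}: {} layers", model.model_type(), model.num_layers());
    /// ```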
pub fn load(config: &UniversalConfig, vb: &mut VarBuilder<R>) -> Result<Self> {
match config.model_type.as_str() {
"mamba2" | "mamba3" => {
let model = super::mamba::Mamba2Model::from_varbuilder(vb, config)?;
Ok(LoadedModel::Mamba2(Box::new(model)))
}
"hybrid" => {
let model = super::hybrid::HybridModel::from_varbuilder(vb, config)?;
Ok(LoadedModel::Hybrid(Box::new(model)))
}
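            // Checked before the attention arm: multimodal configs may also
            // carry an attention config, and must route to MultimodalModel.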
_ if config.vision.is_some() || config.audio.is_some() => {
let model = super::multimodal::MultimodalModel::from_varbuilder(vb, config)?;
Ok(LoadedModel::Multimodal(Box::new(model)))
}
_ if config.attention.is_some() => {
let model = super::llama::Llama::from_varbuilder(vb, config)?;
Ok(LoadedModel::Llama(Box::new(model)))
}
            other => Err(Error::ModelError {
                reason: format!(
                    "Unknown model type '{other}' without attention, vision, or \
                     audio config. Only pure SSM models (mamba2/mamba3) and \
                     hybrid models are supported without an attention \
                     configuration."
                ),
            }),
}
}
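    /// Loads a tensor-parallel model, sharded across the ranks behind `comm`.
    ///
    /// Only attention-based models (`LlamaTp`) support tensor parallelism;
    /// all other architectures return [`Error::ModelError`].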
pub fn load_tp(
config: &UniversalConfig,
vb: &mut VarBuilder<R>,
comm: std::sync::Arc<dyn numr::runtime::Communicator>,
) -> Result<Self> {
if config.attention.is_some() {
let model = super::llama::LlamaTp::from_varbuilder(vb, config, comm)?;
Ok(LoadedModel::LlamaTp(Box::new(model)))
} else {
Err(Error::ModelError {
reason: format!(
"Tensor parallelism not supported for model type '{}' \
(requires attention config)",
config.model_type
),
})
}
}
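    /// Loads a model whose weights come from a GGUF checkpoint. Dispatch is
    /// currently identical to [`Self::load`].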
pub fn load_gguf(config: &UniversalConfig, vb: &mut VarBuilder<R>) -> Result<Self> {
Self::load(config, vb)
}
}
impl<R: Runtime<DType = DType>> LoadedModel<R>
where
R::Client: IndexingOps<R>,
{
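    /// Whether the model needs a KV cache (i.e. it contains attention layers).
    /// Pure SSM models do not; multimodal models defer to their LLM backbone.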
pub fn needs_kv_cache(&self) -> bool {
match self {
LoadedModel::Llama(_) | LoadedModel::LlamaTp(_) | LoadedModel::Hybrid(_) => true,
LoadedModel::Multimodal(m) => m.llm().needs_kv_cache(),
LoadedModel::Mamba2(_) => false,
}
}
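    /// Whether the model needs recurrent SSM state (Mamba-style layers).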
pub fn needs_ssm_state(&self) -> bool {
match self {
LoadedModel::Mamba2(_) | LoadedModel::Hybrid(_) => true,
LoadedModel::Multimodal(m) => m.llm().needs_ssm_state(),
_ => false,
}
}
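    /// A short architecture identifier, e.g. `"llama"`, `"mamba2"`, `"hybrid"`.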
pub fn model_type(&self) -> &str {
match self {
LoadedModel::Llama(_) | LoadedModel::LlamaTp(_) => "llama",
LoadedModel::Mamba2(_) => "mamba2",
LoadedModel::Hybrid(_) => "hybrid",
LoadedModel::Multimodal(m) => m.config().model_type.as_str(),
}
}
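    /// Vocabulary size from the underlying model config.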
pub fn vocab_size(&self) -> usize {
match self {
LoadedModel::Llama(m) => m.config().vocab_size,
LoadedModel::LlamaTp(m) => m.config().vocab_size,
LoadedModel::Mamba2(m) => m.config().vocab_size,
LoadedModel::Hybrid(m) => m.config().vocab_size,
LoadedModel::Multimodal(m) => m.config().vocab_size,
}
}
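    /// Number of layers in the underlying model config.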
pub fn num_layers(&self) -> usize {
match self {
LoadedModel::Llama(m) => m.config().num_layers,
LoadedModel::LlamaTp(m) => m.config().num_layers,
LoadedModel::Mamba2(m) => m.config().num_layers,
LoadedModel::Hybrid(m) => m.config().num_layers,
LoadedModel::Multimodal(m) => m.config().num_layers,
}
}
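    /// Hidden (model) dimension.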
pub fn hidden_size(&self) -> usize {
match self {
LoadedModel::Llama(m) => m.config().hidden_size,
LoadedModel::LlamaTp(m) => m.config().hidden_size,
LoadedModel::Mamba2(m) => m.config().hidden_size,
LoadedModel::Hybrid(m) => m.config().hidden_size,
LoadedModel::Multimodal(m) => m.config().hidden_size,
}
}
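    /// Number of key/value heads, or `None` for models without attention.
    ///
    /// For tensor-parallel models this is the per-rank count: the global
    /// KV-head count divided by the world size.
    ///
    /// A sketch of per-layer KV-cache sizing using this accessor (the
    /// `bytes_per_elem` factor is an assumption, not part of this API):
    ///
    /// ```ignore
    /// if let (Some(kv_heads), Some(head_dim)) = (model.num_kv_heads(), model.head_dim()) {
    ///     // K and V each have shape [max_seq_len, kv_heads, head_dim].
    ///     let elems = 2 * model.max_seq_len() * kv_heads * head_dim;
    ///     let bytes = elems * model.num_layers() * bytes_per_elem;
    /// }
    /// ```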
pub fn num_kv_heads(&self) -> Option<usize> {
match self {
LoadedModel::Llama(m) => m.config().attention.as_ref().map(|a| a.kv_heads()),
LoadedModel::LlamaTp(m) => m
.config()
.attention
.as_ref()
.map(|a| a.kv_heads() / m.world_size()),
LoadedModel::Mamba2(_) => None,
LoadedModel::Hybrid(m) => m.config().attention.as_ref().map(|a| a.kv_heads()),
LoadedModel::Multimodal(m) => m.llm().num_kv_heads(),
}
}
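    /// Per-head dimension, or `None` for models without attention.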
pub fn head_dim(&self) -> Option<usize> {
match self {
LoadedModel::Llama(m) => {
let config = m.config();
config
.attention
.as_ref()
.map(|a| a.head_dim(config.hidden_size))
}
LoadedModel::LlamaTp(m) => {
let config = m.config();
config
.attention
.as_ref()
.map(|a| a.head_dim(config.hidden_size))
}
LoadedModel::Mamba2(_) => None,
LoadedModel::Hybrid(m) => {
let config = m.config();
config
.attention
.as_ref()
.map(|a| a.head_dim(config.hidden_size))
}
LoadedModel::Multimodal(m) => m.llm().head_dim(),
}
}
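    /// Maximum sequence length the model was configured for.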
pub fn max_seq_len(&self) -> usize {
match self {
LoadedModel::Llama(m) => m.config().max_seq_len,
LoadedModel::LlamaTp(m) => m.config().max_seq_len,
LoadedModel::Mamba2(m) => m.config().max_seq_len,
LoadedModel::Hybrid(m) => m.config().max_seq_len,
LoadedModel::Multimodal(m) => m.config().max_seq_len,
}
}
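    /// Whether the model uses a mixture-of-experts (MoE) feed-forward.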
pub fn is_moe(&self) -> bool {
match self {
LoadedModel::Llama(m) => m.config().moe.is_some(),
LoadedModel::LlamaTp(m) => m.config().moe.is_some(),
LoadedModel::Mamba2(m) => m.config().moe.is_some(),
LoadedModel::Hybrid(m) => m.config().moe.is_some(),
LoadedModel::Multimodal(m) => m.llm().is_moe(),
}
}
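    /// The MoE configuration, if the model has one.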
pub fn moe_config(&self) -> Option<&crate::model::config::MoeConfig> {
match self {
LoadedModel::Llama(m) => m.config().moe.as_ref(),
LoadedModel::LlamaTp(m) => m.config().moe.as_ref(),
LoadedModel::Mamba2(m) => m.config().moe.as_ref(),
LoadedModel::Hybrid(m) => m.config().moe.as_ref(),
LoadedModel::Multimodal(m) => m.llm().moe_config(),
}
}
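    /// Precomputed RoPE `(cos, sin)` caches, where available. Returns `None`
    /// for tensor-parallel and pure SSM models.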
pub fn rope_caches(&self) -> Option<(&numr::autograd::Var<R>, &numr::autograd::Var<R>)> {
match self {
LoadedModel::Llama(m) => Some((m.rope().cos_cache(), m.rope().sin_cache())),
            LoadedModel::LlamaTp(_) => None,
            LoadedModel::Mamba2(_) => None,
LoadedModel::Hybrid(m) => Some((m.rope().cos_cache(), m.rope().sin_cache())),
LoadedModel::Multimodal(m) => m.llm().rope_caches(),
}
}
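    /// The Mamba2 layer configuration, for models that contain SSM layers.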
pub fn mamba_config(&self) -> Option<&Mamba2Config> {
match self {
LoadedModel::Mamba2(m) => Some(m.mamba_config()),
LoadedModel::Hybrid(m) => Some(m.mamba_config()),
LoadedModel::Multimodal(m) => m.llm().mamba_config(),
_ => None,
}
}
}
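// Manual Debug: print only the variant name, since the boxed model types
// (and `R`) are not required to implement `Debug`.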
impl<R: Runtime> std::fmt::Debug for LoadedModel<R> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
LoadedModel::Llama(_) => f.debug_tuple("Llama").finish(),
LoadedModel::LlamaTp(_) => f.debug_tuple("LlamaTp").finish(),
LoadedModel::Mamba2(_) => f.debug_tuple("Mamba2").finish(),
LoadedModel::Hybrid(_) => f.debug_tuple("Hybrid").finish(),
LoadedModel::Multimodal(_) => f.debug_tuple("Multimodal").finish(),
}
}
}