pub(crate) mod auto_device_map;
mod diffusion_loaders;
mod embedding_loaders;
mod multimodal_loaders;
mod normal_loaders;
pub use auto_device_map::AutoDeviceMapParams;
use auto_device_map::NonMappedSubModel;
use std::{
fmt::{self, Debug},
path::PathBuf,
str::FromStr,
sync::Arc,
};
use anyhow::Result;
use as_any::AsAny;
use candle_core::{DType, Device};
use mistralrs_quant::{IsqType, QuantizedConfig};
use serde::Deserialize;
use tokio::sync::Mutex;
pub use normal_loaders::{
AutoNormalLoader, DeepSeekV2Loader, DeepSeekV3Loader, GLM4Loader, GLM4MoeLiteLoader,
GLM4MoeLoader, Gemma2Loader, GemmaLoader, GptOssLoader, GraniteMoeHybridLoader, LlamaLoader,
MistralLoader, MixtralLoader, NormalLoaderType, NormalLoadingMetadata, NormalModel,
NormalModelLoader, Phi2Loader, Phi3Loader, Phi3_5MoELoader, Qwen2Loader, Qwen3Loader,
Qwen3MoELoader, Qwen3NextLoader, SmolLm3Loader, Starcoder2Loader,
};
pub use multimodal_loaders::{
AutoMultimodalLoader, Gemma3Loader, Gemma3nLoader, Gemma4Loader, Idefics2Loader,
Idefics3Loader, LLaVALoader, LLaVANextLoader, MiniCpmOLoader, Mistral3Loader,
MultimodalLoaderType, MultimodalModel, MultimodalModelLoader, Phi3VLoader, Phi4MMLoader,
Qwen2VLLoader, Qwen2_5VLLoader, Qwen3VLLoader, Qwen3VLMoELoader, Qwen3_5Loader,
Qwen3_5MoeLoader, VLlama4Loader, VLlamaLoader, VoxtralLoader,
};
pub use embedding_loaders::{
AutoEmbeddingLoader, EmbeddingGemmaLoader, EmbeddingLoaderType, EmbeddingModel,
EmbeddingModelLoader, EmbeddingModule, EmbeddingModulePaths, EmbeddingModuleType,
Qwen3EmbeddingLoader,
};
pub use diffusion_loaders::{
DiffusionLoaderType, DiffusionModel, DiffusionModelLoader, DiffusionModelPaths,
DiffusionModelPathsInner, FluxLoader,
};
use crate::{
matformer::MatformerSliceConfig, paged_attention::ModelConfigLike, DeviceMapMetadata,
DeviceMapSetting, PagedAttentionConfig, TryIntoDType,
};
use super::{paths::AdapterPaths, Pipeline};
/// Resolved file locations for a model, handed to [`Loader::load_model_from_path`].
///
/// Implemented in this module for `LocalModelPaths<PathBuf>` and
/// `EmbeddingModelPaths<PathBuf>`. The `AsAny` supertrait presumably allows callers to
/// downcast to the concrete paths type (TODO confirm against `as_any` usage).
///
/// NOTE(review): several accessors return `&Option<PathBuf>`; `Option<&PathBuf>` (as
/// `get_gen_conf_filename` already uses) is more idiomatic, but changing it would affect
/// every implementor and caller.
pub trait ModelPaths: AsAny + Debug + Send + Sync {
    /// Weight files of the model.
    fn get_weight_filenames(&self) -> &[PathBuf];
    /// Model config file.
    fn get_config_filename(&self) -> &PathBuf;
    /// Tokenizer file.
    fn get_tokenizer_filename(&self) -> &PathBuf;
    /// Optional template file (always `Some` when built via `LocalModelPaths::new`).
    fn get_template_filename(&self) -> &Option<PathBuf>;
    /// Optional generation-config file.
    fn get_gen_conf_filename(&self) -> Option<&PathBuf>;
    /// Optional preprocessor config file.
    fn get_preprocessor_config(&self) -> &Option<PathBuf>;
    /// Optional processor config file.
    fn get_processor_config(&self) -> &Option<PathBuf>;
    /// Optional explicitly supplied chat-template file.
    fn get_chat_template_explicit(&self) -> &Option<PathBuf>;
    /// Adapter paths associated with the model.
    fn get_adapter_paths(&self) -> &AdapterPaths;
    /// Embedding module paths; `None` for paths types that carry no modules.
    fn get_modules(&self) -> Option<&[EmbeddingModulePaths]>;
}
/// File paths for a model resolved on the local filesystem.
///
/// Generic over the path representation `P`; `ModelPaths` is implemented for
/// `P = PathBuf`. Field order is significant for the derived `Debug` output.
#[derive(Clone, Debug)]
pub struct LocalModelPaths<P>
where
    P: Debug,
{
    /// Tokenizer file.
    pub tokenizer_filename: P,
    /// Model config file.
    pub config_filename: P,
    /// Optional template file.
    pub template_filename: Option<P>,
    /// Weight files.
    pub filenames: Vec<P>,
    /// Adapter paths for the model.
    pub adapter_paths: AdapterPaths,
    /// Optional generation-config file.
    pub gen_conf: Option<P>,
    /// Optional preprocessor config file.
    pub preprocessor_config: Option<P>,
    /// Optional processor config file.
    pub processor_config: Option<P>,
    /// Optional explicitly supplied chat-template file.
    pub chat_template_json_filename: Option<P>,
}
impl<P> LocalModelPaths<P>
where
    P: Debug,
{
    /// Bundles the resolved paths for a locally available model.
    ///
    /// `template_filename` is required here and stored as `Some(template_filename)`;
    /// every other optional file is passed through unchanged.
    // Each path is a distinct positional input, hence the lint suppression.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        tokenizer_filename: P,
        config_filename: P,
        template_filename: P,
        filenames: Vec<P>,
        adapter_paths: AdapterPaths,
        gen_conf: Option<P>,
        preprocessor_config: Option<P>,
        processor_config: Option<P>,
        chat_template_json_filename: Option<P>,
    ) -> Self {
        let wrapped_template = Some(template_filename);
        Self {
            filenames,
            adapter_paths,
            config_filename,
            tokenizer_filename,
            template_filename: wrapped_template,
            chat_template_json_filename,
            processor_config,
            preprocessor_config,
            gen_conf,
        }
    }
}
/// Exposes every stored path directly; `get_modules` is always `None` because
/// `LocalModelPaths` carries no embedding modules.
impl ModelPaths for LocalModelPaths<PathBuf> {
    fn get_weight_filenames(&self) -> &[PathBuf] {
        self.filenames.as_slice()
    }

    fn get_config_filename(&self) -> &PathBuf {
        &self.config_filename
    }

    fn get_tokenizer_filename(&self) -> &PathBuf {
        &self.tokenizer_filename
    }

    fn get_template_filename(&self) -> &Option<PathBuf> {
        &self.template_filename
    }

    fn get_gen_conf_filename(&self) -> Option<&PathBuf> {
        self.gen_conf.as_ref()
    }

    fn get_preprocessor_config(&self) -> &Option<PathBuf> {
        &self.preprocessor_config
    }

    fn get_processor_config(&self) -> &Option<PathBuf> {
        &self.processor_config
    }

    fn get_chat_template_explicit(&self) -> &Option<PathBuf> {
        &self.chat_template_json_filename
    }

    fn get_adapter_paths(&self) -> &AdapterPaths {
        &self.adapter_paths
    }

    fn get_modules(&self) -> Option<&[EmbeddingModulePaths]> {
        None
    }
}
/// File paths for an embedding model.
///
/// Generic over the path representation `P`; `ModelPaths` is implemented for
/// `P = PathBuf`. Field order is significant for the derived `Debug` output.
#[derive(Clone, Debug)]
pub struct EmbeddingModelPaths<P>
where
    P: Debug,
{
    /// Tokenizer file.
    pub tokenizer_filename: P,
    /// Model config file.
    pub config_filename: P,
    /// Paths of the embedding modules.
    pub modules: Vec<EmbeddingModulePaths>,
    /// Weight files.
    pub filenames: Vec<P>,
    /// Adapter paths for the model.
    pub adapter_paths: AdapterPaths,
}
impl<P: Debug> EmbeddingModelPaths<P> {
    /// Bundles the resolved paths for an embedding model.
    ///
    /// Unlike `LocalModelPaths`, this carries no template, generation-config,
    /// preprocessor, processor or chat-template files. It carries the embedding
    /// `modules` instead.
    ///
    /// Note the argument order: `modules` comes last, while it is the third field.
    // No `too_many_arguments` suppression: five parameters are below clippy's threshold,
    // so the previous `#[allow]` was dead.
    pub fn new(
        tokenizer_filename: P,
        config_filename: P,
        filenames: Vec<P>,
        adapter_paths: AdapterPaths,
        modules: Vec<EmbeddingModulePaths>,
    ) -> Self {
        Self {
            tokenizer_filename,
            config_filename,
            filenames,
            adapter_paths,
            modules,
        }
    }
}
/// Embedding models only store tokenizer, config, weights, adapters and modules.
/// Every optional single-file accessor therefore reports `None`.
impl ModelPaths for EmbeddingModelPaths<PathBuf> {
    fn get_weight_filenames(&self) -> &[PathBuf] {
        self.filenames.as_slice()
    }

    fn get_config_filename(&self) -> &PathBuf {
        &self.config_filename
    }

    fn get_tokenizer_filename(&self) -> &PathBuf {
        &self.tokenizer_filename
    }

    // The `&None` returns below are constant-promoted borrows, so no backing field is needed.
    fn get_template_filename(&self) -> &Option<PathBuf> {
        &None
    }

    fn get_gen_conf_filename(&self) -> Option<&PathBuf> {
        None
    }

    fn get_preprocessor_config(&self) -> &Option<PathBuf> {
        &None
    }

    fn get_processor_config(&self) -> &Option<PathBuf> {
        &None
    }

    fn get_chat_template_explicit(&self) -> &Option<PathBuf> {
        &None
    }

    fn get_adapter_paths(&self) -> &AdapterPaths {
        &self.adapter_paths
    }

    fn get_modules(&self) -> Option<&[EmbeddingModulePaths]> {
        Some(self.modules.as_slice())
    }
}
/// Where an access token should be obtained from.
///
/// Parsed from / rendered to the `kind[:value]` string form by the `FromStr` and
/// `Display` impls. Variant names are part of the serde representation.
#[derive(Debug, Clone)]
#[derive(serde::Serialize, serde::Deserialize)]
pub enum TokenSource {
    /// The token value itself (`literal:<token>`).
    Literal(String),
    /// Name of an environment variable (`env[:<VAR>]`, default `HUGGING_FACE_HUB_TOKEN`).
    EnvVar(String),
    /// A filesystem path (`path:<file>`).
    Path(String),
    /// `cache`.
    CacheToken,
    /// `none`: no token.
    None,
}
impl FromStr for TokenSource {
    type Err = String;

    /// Parses a token source written as `kind[:value]`.
    ///
    /// - `literal:<token>` -> `TokenSource::Literal`; a value is required.
    /// - `env[:<VAR>]` -> `TokenSource::EnvVar`; defaults to `HUGGING_FACE_HUB_TOKEN`.
    /// - `path:<file>` -> `TokenSource::Path`; a value is required.
    /// - `cache` -> `TokenSource::CacheToken`.
    /// - `none` -> `TokenSource::None`.
    ///
    /// Only the first `:` separates kind from value, so values may themselves contain `:`.
    /// An empty value (e.g. `path:`) is treated exactly like a missing one.
    ///
    /// # Errors
    /// Returns a message when a required value is missing or empty, or the kind is unknown.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let (kind, value) = match s.split_once(':') {
            Some((kind, value)) => (kind, Some(value)),
            None => (s, None),
        };
        // Without this, `path:` / `literal:` pass the "Expected a value" checks and yield
        // an empty path/token, and `env:` yields an empty variable name.
        let value = value.filter(|v| !v.is_empty());
        match kind {
            "literal" => value
                .map(|v| TokenSource::Literal(v.to_string()))
                .ok_or_else(|| "Expected a value for 'literal'".to_string()),
            "env" => Ok(TokenSource::EnvVar(
                value.unwrap_or("HUGGING_FACE_HUB_TOKEN").to_string(),
            )),
            "path" => value
                .map(|v| TokenSource::Path(v.to_string()))
                .ok_or_else(|| "Expected a value for 'path'".to_string()),
            "cache" => Ok(TokenSource::CacheToken),
            "none" => Ok(TokenSource::None),
            _ => Err("Invalid token source format".to_string()),
        }
    }
}
impl fmt::Display for TokenSource {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
TokenSource::Literal(value) => write!(f, "literal:{value}"),
TokenSource::EnvVar(value) => write!(f, "env:{value}"),
TokenSource::Path(value) => write!(f, "path:{value}"),
TokenSource::CacheToken => write!(f, "cache"),
TokenSource::None => write!(f, "none"),
}
}
}
/// How a model is loaded: plain, GGUF-quantized, with an adapter, or a composite
/// (speculative target + draft, or AnyMoE over a target) built from other kinds.
///
/// The `Display` text of each variant comes from its `strum(to_string = ...)` attribute.
#[derive(Clone, Default)]
#[derive(derive_more::From, strum::Display)]
pub enum ModelKind {
    /// No quantization and no adapters; this is the `Default`.
    #[default]
    #[strum(to_string = "normal (no adapters)")]
    Normal,

    /// A quantized model (format in `quant`) without adapters.
    #[strum(to_string = "gguf quantized from {quant} (no adapters)")]
    GgufQuantized { quant: QuantizationKind },

    /// An unquantized model with an adapter.
    #[strum(to_string = "{adapter}")]
    Adapter { adapter: AdapterKind },

    /// A quantized model with an adapter.
    #[strum(to_string = "{adapter}, gguf quantized from {quant}")]
    GgufAdapter {
        adapter: AdapterKind,
        quant: QuantizationKind,
    },

    /// A speculative setup composed of a target kind and a draft kind.
    #[strum(to_string = "speculative: target: `{target}`, draft: `{draft}`")]
    Speculative {
        target: Box<ModelKind>,
        draft: Box<ModelKind>,
    },

    /// AnyMoE wrapping a target kind.
    #[strum(to_string = "anymoe: target: `{target}`")]
    AnyMoe { target: Box<ModelKind> },
}
/// Quantization format of a model.
///
/// `Display` renders variant names in kebab-case (`ggml`, `gguf`, `gptq`).
#[derive(Clone, Copy)]
#[derive(strum::Display, strum::EnumIs, strum::EnumMessage)]
#[strum(serialize_all = "kebab-case")]
pub enum QuantizationKind {
    // NOTE: keep these variant comments as plain `//`. `strum::EnumMessage` exposes variant
    // `///` doc comments through `get_documentation`, which `PrettyName::pretty_name`
    // prefers over `Display`. Adding doc comments here would change user-visible names.
    Ggml,
    Gguf,
    Gptq,
}
/// Adapter type attached to a model.
///
/// `Display` renders variant names in kebab-case (`lora`, `x-lora`).
#[derive(Clone, Copy)]
#[derive(strum::Display, strum::EnumIs, strum::EnumMessage)]
#[strum(serialize_all = "kebab-case")]
pub enum AdapterKind {
    // NOTE: keep these variant comments as plain `//`. `strum::EnumMessage` exposes variant
    // `///` doc comments through `get_documentation`, which `PrettyName::pretty_name`
    // prefers over `Display`. Adding doc comments here would change user-visible names.
    Lora,
    XLora,
}
/// Human-readable name for enums deriving `strum::EnumMessage`.
///
/// Uses the variant's documentation from `EnumMessage::get_documentation` when present.
/// Otherwise it falls back to the `ToString` form.
pub trait PrettyName: strum::EnumMessage + ToString {
    fn pretty_name(&self) -> String {
        self.get_documentation()
            .map_or_else(|| self.to_string(), |doc| doc.to_string())
    }
}

impl PrettyName for QuantizationKind {}
impl PrettyName for AdapterKind {}
impl ModelKind {
    /// Returns `true` if any model described by this kind is quantized.
    ///
    /// For `Speculative` both target and draft are inspected; for `AnyMoe` the target.
    pub fn is_quantized(&self) -> bool {
        self.quantized_kind().iter().any(Option::is_some)
    }

    /// Returns `true` if some quantized model in this kind has a quantization for which
    /// `f` returns `true`.
    ///
    /// `f` is only called for quantized entries, in [`ModelKind::quantized_kind`] order,
    /// and evaluation stops at the first `true`.
    pub fn is_quantized_and(&self, f: impl FnMut(QuantizationKind) -> bool) -> bool {
        self.quantized_kind().into_iter().flatten().any(f)
    }

    /// Quantization of every underlying model, `None` meaning "not quantized".
    ///
    /// Plain, adapter and GGUF kinds yield a single entry. `Speculative` yields the
    /// target's entries followed by the draft's. `AnyMoe` yields the target's entries.
    pub fn quantized_kind(&self) -> Vec<Option<QuantizationKind>> {
        use ModelKind::*;
        match self {
            Normal | Adapter { .. } => vec![None],
            GgufQuantized { quant } | GgufAdapter { quant, .. } => vec![Some(*quant)],
            // Recurse through the boxes by reference; cloning the boxed kinds (a deep copy
            // of the whole sub-tree) is unnecessary for a `&self` query.
            Speculative { target, draft } => {
                let mut kinds = target.quantized_kind();
                kinds.extend(draft.quantized_kind());
                kinds
            }
            AnyMoe { target } => target.quantized_kind(),
        }
    }

    /// Returns `true` if any model described by this kind has an adapter.
    ///
    /// For `Speculative` both target and draft are inspected; for `AnyMoe` the target.
    pub fn is_adapted(&self) -> bool {
        self.adapted_kind().iter().any(Option::is_some)
    }

    /// Returns `true` if some adapted model in this kind has an adapter for which `f`
    /// returns `true`.
    ///
    /// `f` is only called for adapted entries, in [`ModelKind::adapted_kind`] order,
    /// and evaluation stops at the first `true`.
    pub fn is_adapted_and(&self, f: impl FnMut(AdapterKind) -> bool) -> bool {
        self.adapted_kind().into_iter().flatten().any(f)
    }

    /// Adapter of every underlying model, `None` meaning "no adapter".
    ///
    /// Same shape as [`ModelKind::quantized_kind`]: one entry for simple kinds, target
    /// then draft for `Speculative`, and the target's entries for `AnyMoe`.
    pub fn adapted_kind(&self) -> Vec<Option<AdapterKind>> {
        use ModelKind::*;
        match self {
            Normal | GgufQuantized { .. } => vec![None],
            Adapter { adapter } | GgufAdapter { adapter, .. } => vec![Some(*adapter)],
            Speculative { target, draft } => {
                let mut kinds = target.adapted_kind();
                kinds.extend(draft.adapted_kind());
                kinds
            }
            AnyMoe { target } => target.adapted_kind(),
        }
    }
}
#[derive(Deserialize)]
pub struct QuantizationConfigShim {
quantization_config: Option<QuantizedConfig>,
}
impl QuantizationConfigShim {
    /// Returns the weight pack factor declared by the `quantization_config` of `config`.
    ///
    /// `config` is parsed as JSON. Without a quantization config the factor is `1`;
    /// otherwise it is `QuantizedConfig::pack_factor(dtype)`.
    ///
    /// # Errors
    /// Fails if `config` cannot be deserialized by `serde_json`.
    pub fn get_quant_config_pack_factor(config: &str, dtype: DType) -> Result<usize> {
        let shim: QuantizationConfigShim = serde_json::from_str(config)?;
        let factor = shim
            .quantization_config
            .map_or(1, |qc| qc.pack_factor(dtype));
        Ok(factor)
    }
}
/// Size and layer queries a loader answers so that model layers can be assigned to devices.
///
/// Each method receives the model's `config` as a string. By analogy with
/// `QuantizationConfigShim::get_quant_config_pack_factor`, which parses it with
/// `serde_json`, this is presumably the raw JSON config text (TODO confirm against
/// implementors).
pub trait DeviceMappedModelLoader {
    /// Maximum activation size, in elements, for the part of the model that is not
    /// device-mapped (presumably; confirm against implementors).
    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize>;
    /// Maximum activation size, in elements, for the device-mapped layers
    /// (presumably; confirm against implementors).
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize>;
    /// Size in bytes of the weights that are not device-mapped.
    ///
    /// `weight_pack_factor` presumably comes from
    /// `QuantizationConfigShim::get_quant_config_pack_factor` (1 when unquantized).
    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize>;
    /// Weight sizes in bytes of the mapped layers (presumably one entry per layer).
    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>>;
    /// Extra sub-models that are not device-mapped; the default reports none.
    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        None
    }
    /// Number of layers described by `config`.
    fn num_layers(&self, config: &str) -> Result<usize>;
    /// Builds a `ModelConfigLike` view of `config`.
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>>;
    /// Computes the device-map metadata for this model.
    ///
    /// The default forwards every argument unchanged to
    /// `auto_device_map::get_device_layers`, passing `self` as the loader.
    /// The `Self: Sized` bound makes this method unavailable on
    /// `dyn DeviceMappedModelLoader` trait objects.
    // The parameter list mirrors `auto_device_map::get_device_layers`, hence the allow.
    #[allow(clippy::too_many_arguments)]
    fn get_device_layers(
        &self,
        config: &str,
        num_layers: usize,
        layer_sizes_in_bytes: Vec<usize>,
        non_mapped_size_in_bytes: usize,
        total_model_size_in_bytes: usize,
        devices: &[Device],
        dtype: DType,
        params: &AutoDeviceMapParams,
        paged_attn_config: Option<&PagedAttentionConfig>,
    ) -> Result<DeviceMapMetadata>
    where
        Self: Sized,
    {
        auto_device_map::get_device_layers(
            self,
            config,
            num_layers,
            layer_sizes_in_bytes,
            non_mapped_size_in_bytes,
            total_model_size_in_bytes,
            devices,
            dtype,
            params,
            paged_attn_config,
        )
    }
}
/// A model loader that produces a [`Pipeline`].
///
/// Both loading entry points return the pipeline as
/// `Arc<tokio::sync::Mutex<dyn Pipeline + Send + Sync>>`.
pub trait Loader: Send + Sync {
    /// Loads the model, presumably fetching its files from the Hugging Face Hub
    /// (per the method name; confirm against implementors).
    ///
    /// - `revision`: optional revision to load.
    /// - `token_source`: where to obtain the access token (see `TokenSource`).
    /// - `dtype`: requested dtype, resolved through `TryIntoDType`.
    /// - `device`: target device.
    /// - `silent`: presumably suppresses progress output (TODO confirm).
    /// - `mapper`: device-mapping setting.
    /// - `in_situ_quant`: optional in-situ quantization type.
    /// - `paged_attn_config`: PagedAttention configuration, if any.
    // The options are passed individually and the return type is spelled out in full.
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>>;
    /// Loads the model from already-resolved `paths`.
    ///
    /// The remaining parameters have the same meaning as in `load_model_from_hf`.
    // NOTE(review): `&Box<dyn ModelPaths>` (hence `clippy::borrowed_box`) would be more
    // idiomatic as `&dyn ModelPaths`, but changing it affects every implementor and caller.
    #[allow(
        clippy::type_complexity,
        clippy::too_many_arguments,
        clippy::borrowed_box
    )]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>>;
    /// Identifier of the model this loader targets (presumably the model id; confirm).
    fn get_id(&self) -> String;
    /// The kind of model (plain, quantized, adapter, ...) this loader produces.
    fn get_kind(&self) -> ModelKind;
}