use super::varbuilder_utils::{
from_mmaped_safetensors, load_preload_adapters, DeviceForLoadTensor,
};
use anyhow::Result;
use candle_core::{quantized::ggml_file, DType};
use mistralrs_quant::ShardedVarBuilder;
use std::{collections::HashMap, path::PathBuf, sync::Arc};
use crate::{
device_map::DeviceMapper,
gguf::Content,
lora::{LoraConfig, Ordering},
paged_attention::AttentionImplementation,
pipeline::{AdapterPaths, ModelPaths},
xlora_models::XLoraConfig,
};
/// Inputs needed to build a model from a parsed GGML file.
#[derive(derive_more::From)]
pub struct FileGGML {
    /// Parsed GGML file content (tensors + hyperparameters).
    pub ct: ggml_file::Content,
    /// Grouped-query attention (GQA) factor passed through to model construction.
    pub gqa: usize,
    /// Target dtype forwarded to `from_ggml` during model construction.
    pub dtype: DType,
}
/// Target device together with the mapper that places model parts across devices.
#[derive(derive_more::From)]
pub struct Device<'a> {
    // Kept private; only the `try_into_model` impls in this module read it.
    device: &'a candle_core::Device,
    /// Maps model components to devices during loading.
    pub mapper: Box<dyn DeviceMapper + Send + Sync>,
}
/// LoRA/X-LoRA adapter state loaded from safetensors, ready to be handed to
/// the `FromAdapterGGML`/`FromAdapterGGUF` constructors.
pub struct Adapter<'a> {
    /// X-LoRA classifier configuration, when present.
    pub xlora_config: Option<XLoraConfig>,
    /// LoRA configurations with their identifying string pairs.
    pub lora_config: &'a [((String, String), LoraConfig)],
    /// VarBuilder over the loaded adapter (and, for X-LoRA, classifier) tensors.
    pub vb: ShardedVarBuilder,
    /// Adapter application ordering.
    pub ordering: &'a Ordering,
    /// Optional preloaded adapters keyed by name.
    pub preload_adapters: Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
}
impl<'a> Adapter<'a> {
#[allow(clippy::borrowed_box)]
pub fn try_new<'b: 'a>(
paths: &'b Box<dyn ModelPaths>,
device: &'b candle_core::Device,
silent: bool,
is_xlora: bool,
) -> Result<Self> {
let AdapterPaths::XLora {
adapter_configs,
adapter_safetensors,
classifier_path,
xlora_order,
xlora_config,
lora_preload_adapter_info,
} = paths.get_adapter_paths()
else {
todo!()
};
let lora_config = adapter_configs.as_ref().unwrap();
let ordering = xlora_order.as_ref().unwrap();
let preload_adapters = load_preload_adapters(
lora_preload_adapter_info,
candle_core::DType::F32,
device,
silent,
)?;
let mut xlora_paths: Vec<PathBuf> = vec![];
if is_xlora {
xlora_paths = vec![classifier_path.as_ref().unwrap().to_path_buf()];
}
let vb = from_mmaped_safetensors(
xlora_paths,
adapter_safetensors
.as_ref()
.unwrap()
.iter()
.map(|(_, x)| (*x).to_owned())
.collect::<Vec<_>>(),
Some(candle_core::DType::F32),
device,
vec![None],
silent,
None,
|_| true,
Arc::new(|_| DeviceForLoadTensor::Base),
)?;
Ok(Self {
lora_config,
xlora_config: xlora_config.clone(),
vb,
ordering,
preload_adapters,
})
}
}
/// GGML model-construction parameters (wraps [`FileGGML`]).
pub struct ParamsGGML(pub FileGGML);
/// GGUF model-construction parameters: parsed content, device/mapper,
/// attention implementation, and dtype.
pub struct ParamsGGUF<'a, R: std::io::Seek + std::io::Read>(
    pub Content<'a, R>,
    pub Device<'a>,
    pub AttentionImplementation,
    pub DType,
);
/// Marker type for "no adapter attached".
pub struct NoAdapter {}
/// Marker trait over the supported quantized parameter sets (GGML / GGUF).
pub trait QuantParams {}
impl QuantParams for ParamsGGML {}
impl<R: std::io::Seek + std::io::Read> QuantParams for ParamsGGUF<'_, R> {}
/// Marker trait over the adapter options ([`Adapter`] or [`NoAdapter`]).
pub trait MaybeAdapter {}
impl MaybeAdapter for Adapter<'_> {}
impl MaybeAdapter for NoAdapter {}
/// Pairs a quantized parameter set with an adapter choice; the concrete
/// `(Q, A)` combination selects which `try_into_model` impl applies.
#[derive(derive_more::From)]
pub struct Config<Q: QuantParams, A: MaybeAdapter> {
    pub quant: Q,
    pub adapter: A,
}
#[allow(clippy::large_enum_variant)]
// `variantly` derives accessors such as `expect_quantized`/`expect_adapted`,
// which the `TryFrom` impls below rely on.
#[derive(variantly::Variantly)]
pub enum ModelParams<'a, Q>
where
    Q: QuantParams,
{
    /// Plain quantized model, no adapter.
    Quantized(Config<Q, NoAdapter>),
    /// Quantized model with a LoRA/X-LoRA adapter attached.
    Adapted(Config<Q, Adapter<'a>>),
}
impl<'a, Q: QuantParams> ModelParams<'a, Q> {
pub fn new<'b: 'a>(quant: Q, adapter: Option<Adapter<'b>>) -> Self {
match adapter {
None => Self::Quantized((quant, NoAdapter {}).into()),
Some(a) => Self::Adapted((quant, a).into()),
}
}
}
/// Constructing a model from parsed GGML content (no adapter).
pub trait FromGGML {
    /// Builds the model from GGML content with the given GQA factor and dtype.
    fn from_ggml(
        ct: ggml_file::Content,
        gqa: usize,
        dtype: DType,
    ) -> Result<Self, candle_core::Error>
    where
        Self: Sized;
}
/// Constructing a model from parsed GGUF content (no adapter).
pub trait FromGGUF {
    /// Builds the model from GGUF content on `device`, using `mapper` for
    /// device placement and the given attention implementation and dtype.
    fn from_gguf<R: std::io::Seek + std::io::Read>(
        ct: Content<'_, R>,
        device: &candle_core::Device,
        mapper: Box<dyn DeviceMapper + Send + Sync>,
        attention_mechanism: AttentionImplementation,
        dtype: DType,
    ) -> Result<Self, candle_core::Error>
    where
        Self: Sized;
}
/// Constructing a model from parsed GGML content with a LoRA/X-LoRA adapter.
pub trait FromAdapterGGML {
    /// Builds the model from GGML content plus the adapter state
    /// (LoRA configs, adapter VarBuilder, ordering, optional X-LoRA
    /// classifier config, and optional preload adapters).
    #[allow(clippy::too_many_arguments)]
    fn from_ggml(
        ct: ggml_file::Content,
        gqa: usize,
        lora_config: &[((String, String), LoraConfig)],
        vb: &ShardedVarBuilder,
        ordering: &Ordering,
        xlora_config: Option<XLoraConfig>,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
        dtype: DType,
    ) -> Result<Self, candle_core::Error>
    where
        Self: Sized;
}
/// Constructing a model from parsed GGUF content with a LoRA/X-LoRA adapter.
pub trait FromAdapterGGUF {
    /// Builds the model from GGUF content on `device` plus the adapter state;
    /// `mapper` handles device placement.
    #[allow(clippy::too_many_arguments)]
    fn from_gguf<R: std::io::Seek + std::io::Read>(
        ct: Content<'_, R>,
        device: &candle_core::Device,
        lora_config: &[((String, String), LoraConfig)],
        vb: &ShardedVarBuilder,
        ordering: &Ordering,
        xlora_config: Option<XLoraConfig>,
        mapper: Box<dyn DeviceMapper + Send + Sync>,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
        dtype: DType,
    ) -> Result<Self, candle_core::Error>
    where
        Self: Sized;
}
impl Config<ParamsGGML, NoAdapter> {
    /// Consumes the config and constructs a model from the GGML content.
    pub fn try_into_model<T: FromGGML>(self) -> Result<T, candle_core::Error> {
        let FileGGML { ct, gqa, dtype } = self.quant.0;
        T::from_ggml(ct, gqa, dtype)
    }
}
impl Config<ParamsGGML, Adapter<'_>> {
pub fn try_into_model<T: FromAdapterGGML>(self) -> Result<T, candle_core::Error> {
let ParamsGGML(FileGGML { ct, gqa, dtype }) = self.quant;
let Adapter {
xlora_config,
lora_config,
vb,
ordering,
preload_adapters,
} = self.adapter;
T::from_ggml(
ct,
gqa,
lora_config,
&vb,
ordering,
xlora_config,
&preload_adapters,
dtype,
)
}
}
impl<R: std::io::Seek + std::io::Read> Config<ParamsGGUF<'_, R>, NoAdapter> {
    /// Consumes the config and constructs a model from the GGUF content on the
    /// selected device, with the given attention implementation and dtype.
    pub fn try_into_model<T: FromGGUF>(self) -> Result<T, candle_core::Error> {
        let ParamsGGUF(ct, dev, attn, dtype) = self.quant;
        T::from_gguf(ct, dev.device, dev.mapper, attn, dtype)
    }
}
impl<R: std::io::Seek + std::io::Read> Config<ParamsGGUF<'_, R>, Adapter<'_>> {
pub fn try_into_model<T: FromAdapterGGUF>(self) -> Result<T, candle_core::Error> {
let ParamsGGUF(ct, Device { device, mapper }, _attention_implementation, dtype) =
self.quant;
let Adapter {
xlora_config,
lora_config,
vb,
ordering,
preload_adapters,
} = self.adapter;
T::from_gguf(
ct,
device,
lora_config,
&vb,
ordering,
xlora_config,
mapper,
&preload_adapters,
dtype,
)
}
}
use crate::{
models::quantized_llama::ModelWeights as QLlama,
models::quantized_phi2::ModelWeights as QPhi,
models::quantized_phi3::ModelWeights as QPhi3,
models::quantized_qwen::ModelWeights as QQwen,
models::quantized_qwen3::ModelWeights as QQwen3,
models::quantized_qwen3_moe::ModelWeights as QQwen3MoE,
models::quantized_starcoder2::ModelWeights as QStarcoder2,
xlora_models::{XLoraQLlama, XLoraQPhi3},
};
use akin::akin;
impl TryFrom<ModelParams<'_, ParamsGGML>> for QLlama {
    type Error = candle_core::Error;
    /// Builds a plain (non-adapter) quantized Llama from GGML params.
    fn try_from(params: ModelParams<'_, ParamsGGML>) -> Result<Self, Self::Error> {
        params
            .expect_quantized("`Config` should be GGML Quantized")
            .try_into_model()
    }
}
impl TryFrom<ModelParams<'_, ParamsGGML>> for XLoraQLlama {
    type Error = candle_core::Error;
    /// Builds an adapter-enabled quantized Llama from GGML params.
    fn try_from(params: ModelParams<'_, ParamsGGML>) -> Result<Self, Self::Error> {
        params
            .expect_adapted("`Config` should be GGML Quantized with an Adapter")
            .try_into_model()
    }
}
// Generates identical `TryFrom` impls (non-adapter GGUF path) for each listed
// model type; `akin!` substitutes `*models_gguf` per element.
akin! {
    let &models_gguf = [QLlama, QPhi, QPhi3, QStarcoder2, QQwen, QQwen3, QQwen3MoE];
    impl<R: std::io::Seek + std::io::Read> TryFrom<ModelParams<'_, ParamsGGUF<'_, R>>> for *models_gguf {
        type Error = candle_core::Error;
        fn try_from(params: ModelParams<'_, ParamsGGUF<'_, R>>) -> Result<Self, Self::Error> {
            let config = params.expect_quantized("`Config` should be GGUF Quantized");
            config.try_into_model()
        }
    }
}
// Generates identical `TryFrom` impls (adapter-enabled GGUF path) for each
// listed X-LoRA model type.
akin! {
    let &models_gguf_a = [XLoraQLlama, XLoraQPhi3];
    impl<R: std::io::Seek + std::io::Read> TryFrom<ModelParams<'_, ParamsGGUF<'_, R>>> for *models_gguf_a {
        type Error = candle_core::Error;
        fn try_from(params: ModelParams<'_, ParamsGGUF<'_, R>>) -> Result<Self, Self::Error> {
            let config = params.expect_adapted("`Config` should be GGUF Quantized with an Adapter");
            config.try_into_model()
        }
    }
}