use candle_core::Device;
use mistralrs_core::*;
use mistralrs_core::{SearchCallback, Tool, ToolCallback};
use crate::{IsqBits, IsqSetting};
use std::collections::HashMap;
use std::{
ops::{Deref, DerefMut},
path::PathBuf,
sync::Arc,
};
use crate::model_builder_trait::{build_model_from_pipeline, build_multimodal_pipeline};
use crate::Model;
/// Builder that collects the configuration for a multimodal model and turns
/// it into a [`Model`] via [`MultimodalModelBuilder::build`].
///
/// All fields are crate-visible; the defaults listed below are the values
/// assigned by [`MultimodalModelBuilder::new`].
#[derive(Clone)]
pub struct MultimodalModelBuilder {
/// Model identifier, stored as the string form of the value passed to `new`.
pub(crate) model_id: String,
/// Token source; defaults to `TokenSource::CacheToken`.
pub(crate) token_source: TokenSource,
/// Optional revision; `None` by default.
// NOTE(review): presumably the Hugging Face revision of `model_id` - confirm.
pub(crate) hf_revision: Option<String>,
/// Optional path; `None` by default.
// NOTE(review): presumably where a UQFF file is written - confirm.
pub(crate) write_uqff: Option<PathBuf>,
/// UQFF file paths. Set by the deprecated `from_uqff` method and by
/// `UqffMultimodalModelBuilder::new`; `None` by default.
pub(crate) from_uqff: Option<Vec<PathBuf>>,
/// Optional calibration file path; `None` by default.
pub(crate) calibration_file: Option<PathBuf>,
/// Optional imatrix path; `None` by default.
pub(crate) imatrix: Option<PathBuf>,
/// Optional chat template; `None` by default.
pub(crate) chat_template: Option<String>,
/// Optional explicit Jinja setting; `None` by default.
pub(crate) jinja_explicit: Option<String>,
/// Optional tokenizer JSON setting; `None` by default.
pub(crate) tokenizer_json: Option<String>,
/// Optional device mapping; `None` by default.
pub(crate) device_mapping: Option<DeviceMapSetting>,
/// Optional maximum edge value, set by `with_max_edge`; `None` by default.
pub(crate) max_edge: Option<u32>,
/// Optional cache path; `None` by default.
pub(crate) hf_cache_path: Option<PathBuf>,
/// Optional search embedding model; `None` by default.
pub(crate) search_embedding_model: Option<SearchEmbeddingModel>,
/// Optional shared search callback; `None` by default.
pub(crate) search_callback: Option<Arc<SearchCallback>>,
/// Shared tool callbacks keyed by string; empty by default.
pub(crate) tool_callbacks: HashMap<String, Arc<ToolCallback>>,
/// Tool callbacks bundled as `ToolCallbackWithTool`, keyed by string; empty by default.
pub(crate) tool_callbacks_with_tools: HashMap<String, ToolCallbackWithTool>,
/// Optional explicit device; `None` by default.
pub(crate) device: Option<Device>,
/// Optional Matformer configuration path; `None` by default.
pub(crate) matformer_config_path: Option<PathBuf>,
/// Optional Matformer slice name; `None` by default.
pub(crate) matformer_slice_name: Option<String>,
/// ISQ organization; defaults to `IsqOrganization::Default`.
pub(crate) organization: IsqOrganization,
/// Optional in-memory topology; `None` by default.
pub(crate) topology: Option<Topology>,
/// Optional topology path; `None` by default.
pub(crate) topology_path: Option<String>,
/// Optional loader type, set by `with_loader_type`; `None` by default.
pub(crate) loader_type: Option<MultimodalLoaderType>,
/// Model dtype; defaults to `ModelDType::Auto`.
pub(crate) dtype: ModelDType,
/// Force-CPU flag; defaults to `false`.
pub(crate) force_cpu: bool,
/// Optional ISQ setting; `None` by default.
pub(crate) isq: Option<IsqSetting>,
/// Throughput-logging flag; defaults to `false`.
pub(crate) throughput_logging: bool,
/// Optional paged-attention configuration; `None` by default.
pub(crate) paged_attn_cfg: Option<PagedAttentionConfig>,
/// Maximum number of sequences; defaults to `32`.
pub(crate) max_num_seqs: usize,
/// Logging flag; defaults to `false`.
pub(crate) with_logging: bool,
/// Optional prefix-cache size; `None` by default.
pub(crate) prefix_cache_n: Option<usize>,
}
impl MultimodalModelBuilder {
    /// Creates a builder for `model_id` with every optional setting unset.
    ///
    /// Non-optional defaults: `TokenSource::CacheToken`, `ModelDType::Auto`,
    /// `IsqOrganization::Default`, `max_num_seqs = 32`, empty callback maps,
    /// and all boolean flags `false`.
    pub fn new(model_id: impl ToString) -> Self {
        // Initialisers are listed in the same order as the struct declaration.
        Self {
            model_id: model_id.to_string(),
            token_source: TokenSource::CacheToken,
            hf_revision: None,
            write_uqff: None,
            from_uqff: None,
            calibration_file: None,
            imatrix: None,
            chat_template: None,
            jinja_explicit: None,
            tokenizer_json: None,
            device_mapping: None,
            max_edge: None,
            hf_cache_path: None,
            search_embedding_model: None,
            search_callback: None,
            tool_callbacks: HashMap::new(),
            tool_callbacks_with_tools: HashMap::new(),
            device: None,
            matformer_config_path: None,
            matformer_slice_name: None,
            organization: IsqOrganization::Default,
            topology: None,
            topology_path: None,
            loader_type: None,
            dtype: ModelDType::Auto,
            force_cpu: false,
            isq: None,
            throughput_logging: false,
            paged_attn_cfg: None,
            max_num_seqs: 32,
            with_logging: false,
            prefix_cache_n: None,
        }
    }

    // Shared builder methods are generated by this crate-level macro.
    common_builder_methods!();

    /// Stores `loader_type` as the loader type and returns the builder.
    pub fn with_loader_type(self, loader_type: MultimodalLoaderType) -> Self {
        Self {
            loader_type: Some(loader_type),
            ..self
        }
    }

    /// Stores `path` as the UQFF file list and returns the builder.
    #[deprecated(
        note = "Use `UqffMultimodalModelBuilder` to load a UQFF model instead of the generic `from_uqff`"
    )]
    pub fn from_uqff(self, path: Vec<PathBuf>) -> Self {
        Self {
            from_uqff: Some(path),
            ..self
        }
    }

    /// Stores `max_edge` as the maximum edge value and returns the builder.
    pub fn with_max_edge(self, max_edge: u32) -> Self {
        Self {
            max_edge: Some(max_edge),
            ..self
        }
    }

    /// Consumes the builder and produces a [`Model`].
    ///
    /// Awaits `build_multimodal_pipeline` on the builder, then passes the
    /// resulting pipeline, scheduler config and model config to
    /// `build_model_from_pipeline`.
    ///
    /// # Errors
    ///
    /// Returns the error from `build_multimodal_pipeline` if it fails.
    pub async fn build(self) -> anyhow::Result<Model> {
        let (pipeline, scheduler_config, add_model_config) =
            build_multimodal_pipeline(self).await?;
        let model =
            build_model_from_pipeline(pipeline, scheduler_config, add_model_config).await;
        Ok(model)
    }
}
/// Newtype wrapper around [`MultimodalModelBuilder`] whose constructor takes
/// the UQFF file paths up front and stores them in the inner builder's
/// `from_uqff` field.
#[derive(Clone)]
pub struct UqffMultimodalModelBuilder(MultimodalModelBuilder);
impl UqffMultimodalModelBuilder {
pub fn new(model_id: impl ToString, uqff_file: Vec<PathBuf>) -> Self {
let mut inner = MultimodalModelBuilder::new(model_id);
inner.from_uqff = Some(uqff_file);
Self(inner)
}
pub async fn build(self) -> anyhow::Result<Model> {
self.0.build().await
}
pub fn into_inner(self) -> MultimodalModelBuilder {
self.0
}
}
impl Deref for UqffMultimodalModelBuilder {
type Target = MultimodalModelBuilder;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl DerefMut for UqffMultimodalModelBuilder {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl From<UqffMultimodalModelBuilder> for MultimodalModelBuilder {
fn from(value: UqffMultimodalModelBuilder) -> Self {
value.0
}
}