Skip to main content

mistralrs/
vision_model.rs

1use candle_core::Device;
2use mistralrs_core::*;
3use mistralrs_core::{SearchCallback, Tool, ToolCallback};
4use std::collections::HashMap;
5use std::{
6    ops::{Deref, DerefMut},
7    path::PathBuf,
8    sync::Arc,
9};
10
11use crate::model_builder_trait::{build_model_from_pipeline, build_vision_pipeline};
12use crate::Model;
13
14#[derive(Clone)]
15/// Configure a vision model with the various parameters for loading, running, and other inference behaviors.
16pub struct VisionModelBuilder {
17    // Loading model
18    pub(crate) model_id: String,
19    pub(crate) token_source: TokenSource,
20    pub(crate) hf_revision: Option<String>,
21    pub(crate) write_uqff: Option<PathBuf>,
22    pub(crate) from_uqff: Option<Vec<PathBuf>>,
23    pub(crate) calibration_file: Option<PathBuf>,
24    pub(crate) imatrix: Option<PathBuf>,
25    pub(crate) chat_template: Option<String>,
26    pub(crate) jinja_explicit: Option<String>,
27    pub(crate) tokenizer_json: Option<String>,
28    pub(crate) device_mapping: Option<DeviceMapSetting>,
29    pub(crate) max_edge: Option<u32>,
30    pub(crate) hf_cache_path: Option<PathBuf>,
31    pub(crate) search_embedding_model: Option<SearchEmbeddingModel>,
32    pub(crate) search_callback: Option<Arc<SearchCallback>>,
33    pub(crate) tool_callbacks: HashMap<String, Arc<ToolCallback>>,
34    pub(crate) tool_callbacks_with_tools: HashMap<String, ToolCallbackWithTool>,
35    pub(crate) device: Option<Device>,
36    pub(crate) matformer_config_path: Option<PathBuf>,
37    pub(crate) matformer_slice_name: Option<String>,
38
39    // Model running
40    pub(crate) topology: Option<Topology>,
41    pub(crate) topology_path: Option<String>,
42    pub(crate) loader_type: Option<VisionLoaderType>,
43    pub(crate) dtype: ModelDType,
44    pub(crate) force_cpu: bool,
45    pub(crate) isq: Option<IsqType>,
46    pub(crate) throughput_logging: bool,
47
48    // Other things
49    pub(crate) paged_attn_cfg: Option<PagedAttentionConfig>,
50    pub(crate) max_num_seqs: usize,
51    pub(crate) with_logging: bool,
52    pub(crate) prefix_cache_n: Option<usize>,
53}
54
55impl VisionModelBuilder {
56    /// A few defaults are applied here:
57    /// - Token source is from the cache (.cache/huggingface/token)
58    /// - Maximum number of sequences running is 32
59    /// - Automatic device mapping with model defaults according to `AutoDeviceMapParams`
60    /// - By default, web searching compatible with the OpenAI `web_search_options` setting is disabled.
61    pub fn new(model_id: impl ToString) -> Self {
62        Self {
63            model_id: model_id.to_string(),
64            topology: None,
65            topology_path: None,
66            write_uqff: None,
67            from_uqff: None,
68            chat_template: None,
69            tokenizer_json: None,
70            max_edge: None,
71            loader_type: None,
72            dtype: ModelDType::Auto,
73            force_cpu: false,
74            token_source: TokenSource::CacheToken,
75            hf_revision: None,
76            isq: None,
77            max_num_seqs: 32,
78            with_logging: false,
79            device_mapping: None,
80            calibration_file: None,
81            imatrix: None,
82            jinja_explicit: None,
83            throughput_logging: false,
84            paged_attn_cfg: None,
85            hf_cache_path: None,
86            search_embedding_model: None,
87            search_callback: None,
88            tool_callbacks: HashMap::new(),
89            tool_callbacks_with_tools: HashMap::new(),
90            device: None,
91            matformer_config_path: None,
92            matformer_slice_name: None,
93            prefix_cache_n: None,
94        }
95    }
96
97    /// Enable searching compatible with the OpenAI `web_search_options` setting. This loads the selected search embedding reranker (EmbeddingGemma by default).
98    pub fn with_search(mut self, search_embedding_model: SearchEmbeddingModel) -> Self {
99        self.search_embedding_model = Some(search_embedding_model);
100        self
101    }
102
103    /// Override the search function used when `web_search_options` is enabled.
104    pub fn with_search_callback(mut self, callback: Arc<SearchCallback>) -> Self {
105        self.search_callback = Some(callback);
106        self
107    }
108
109    pub fn with_tool_callback(
110        mut self,
111        name: impl Into<String>,
112        callback: Arc<ToolCallback>,
113    ) -> Self {
114        self.tool_callbacks.insert(name.into(), callback);
115        self
116    }
117
118    /// Register a callback with an associated Tool definition that will be automatically
119    /// added to requests when tool callbacks are active.
120    pub fn with_tool_callback_and_tool(
121        mut self,
122        name: impl Into<String>,
123        callback: Arc<ToolCallback>,
124        tool: Tool,
125    ) -> Self {
126        let name = name.into();
127        self.tool_callbacks_with_tools
128            .insert(name, ToolCallbackWithTool { callback, tool });
129        self
130    }
131
132    /// Enable runner throughput logging.
133    pub fn with_throughput_logging(mut self) -> Self {
134        self.throughput_logging = true;
135        self
136    }
137
138    /// Explicit JINJA chat template file (.jinja) to be used. If specified, this overrides all other chat templates.
139    pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
140        self.jinja_explicit = Some(jinja_explicit);
141        self
142    }
143
144    /// Set the model topology for use during loading. If there is an overlap, the topology type is used over the ISQ type.
145    pub fn with_topology(mut self, topology: Topology) -> Self {
146        self.topology = Some(topology);
147        self
148    }
149
150    /// Set the model topology from a path. This preserves the path for unload/reload support.
151    /// If there is an overlap, the topology type is used over the ISQ type.
152    pub fn with_topology_from_path<P: AsRef<std::path::Path>>(
153        mut self,
154        path: P,
155    ) -> anyhow::Result<Self> {
156        let path_str = path.as_ref().to_string_lossy().to_string();
157        self.topology = Some(Topology::from_path(&path)?);
158        self.topology_path = Some(path_str);
159        Ok(self)
160    }
161
162    /// Literal Jinja chat template OR Path (ending in `.json`) to one.
163    pub fn with_chat_template(mut self, chat_template: impl ToString) -> Self {
164        self.chat_template = Some(chat_template.to_string());
165        self
166    }
167
168    /// Path to a discrete `tokenizer.json` file.
169    pub fn with_tokenizer_json(mut self, tokenizer_json: impl ToString) -> Self {
170        self.tokenizer_json = Some(tokenizer_json.to_string());
171        self
172    }
173
174    /// Manually set the model loader type. Otherwise, it will attempt to automatically
175    /// determine the loader type.
176    pub fn with_loader_type(mut self, loader_type: VisionLoaderType) -> Self {
177        self.loader_type = Some(loader_type);
178        self
179    }
180
181    /// Load the model in a certain dtype.
182    pub fn with_dtype(mut self, dtype: ModelDType) -> Self {
183        self.dtype = dtype;
184        self
185    }
186
187    /// Force usage of the CPU device. Do not use PagedAttention with this.
188    pub fn with_force_cpu(mut self) -> Self {
189        self.force_cpu = true;
190        self
191    }
192
193    /// Source of the Hugging Face token.
194    pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
195        self.token_source = token_source;
196        self
197    }
198
199    /// Set the revision to use for a Hugging Face remote model.
200    pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
201        self.hf_revision = Some(revision.to_string());
202        self
203    }
204
205    /// Use ISQ of a certain type. If there is an overlap, the topology type is used over the ISQ type.
206    pub fn with_isq(mut self, isq: IsqType) -> Self {
207        self.isq = Some(isq);
208        self
209    }
210
211    /// Utilise this calibration_file file during ISQ
212    pub fn with_calibration_file(mut self, path: PathBuf) -> Self {
213        self.calibration_file = Some(path);
214        self
215    }
216
217    /// Enable PagedAttention. Configure PagedAttention with a [`PagedAttentionConfig`] object, which
218    /// can be created with sensible values with a [`PagedAttentionMetaBuilder`](crate::PagedAttentionMetaBuilder).
219    ///
220    /// If PagedAttention is not supported (query with [`paged_attn_supported`]), this will do nothing.
221    pub fn with_paged_attn(
222        mut self,
223        paged_attn_cfg: impl FnOnce() -> anyhow::Result<PagedAttentionConfig>,
224    ) -> anyhow::Result<Self> {
225        if paged_attn_supported() {
226            self.paged_attn_cfg = Some(paged_attn_cfg()?);
227        } else {
228            self.paged_attn_cfg = None;
229        }
230        Ok(self)
231    }
232
233    /// Set the maximum number of sequences which can be run at once.
234    pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
235        self.max_num_seqs = max_num_seqs;
236        self
237    }
238
239    /// Set the number of sequences to hold in the prefix cache. Set to `None` to disable the prefix cacher.
240    pub fn with_prefix_cache_n(mut self, n_seqs: Option<usize>) -> Self {
241        self.prefix_cache_n = n_seqs;
242        self
243    }
244
245    /// Enable logging.
246    pub fn with_logging(mut self) -> Self {
247        self.with_logging = true;
248        self
249    }
250
251    /// Provide metadata to initialize the device mapper.
252    pub fn with_device_mapping(mut self, device_mapping: DeviceMapSetting) -> Self {
253        self.device_mapping = Some(device_mapping);
254        self
255    }
256
257    #[deprecated(
258        note = "Use `UqffTextModelBuilder` to load a UQFF model instead of the generic `from_uqff`"
259    )]
260    /// Path to read a `.uqff` file from. Other necessary configuration files must be present at this location.
261    ///
262    /// For example, these include:
263    /// - `residual.safetensors`
264    /// - `tokenizer.json`
265    /// - `config.json`
266    /// - More depending on the model
267    pub fn from_uqff(mut self, path: Vec<PathBuf>) -> Self {
268        self.from_uqff = Some(path);
269        self
270    }
271
272    /// Automatically resize and pad images to this maximum edge length. Aspect ratio is preserved.
273    /// This is only supported on the Qwen2-VL and Idefics 2 models. Others handle this internally.
274    pub fn from_max_edge(mut self, max_edge: u32) -> Self {
275        self.max_edge = Some(max_edge);
276        self
277    }
278
279    /// Path to write a `.uqff` file to and serialize the other necessary files.
280    ///
281    /// The parent (part of the path excluding the filename) will determine where any other files
282    /// serialized are written to.
283    ///
284    /// For example, these include:
285    /// - `residual.safetensors`
286    /// - `tokenizer.json`
287    /// - `config.json`
288    /// - More depending on the model
289    pub fn write_uqff(mut self, path: PathBuf) -> Self {
290        self.write_uqff = Some(path);
291        self
292    }
293
294    /// Cache path for Hugging Face models downloaded locally
295    pub fn from_hf_cache_pathf(mut self, hf_cache_path: PathBuf) -> Self {
296        self.hf_cache_path = Some(hf_cache_path);
297        self
298    }
299
300    /// Set the main device to load this model onto. Automatic device mapping will be performed starting with this device.
301    pub fn with_device(mut self, device: Device) -> Self {
302        self.device = Some(device);
303        self
304    }
305
306    /// Path to a Matryoshka Transformer configuration CSV file.
307    pub fn with_matformer_config_path(mut self, path: PathBuf) -> Self {
308        self.matformer_config_path = Some(path);
309        self
310    }
311
312    /// Name of the slice to use from the Matryoshka Transformer configuration.
313    pub fn with_matformer_slice_name(mut self, name: String) -> Self {
314        self.matformer_slice_name = Some(name);
315        self
316    }
317
318    pub async fn build(self) -> anyhow::Result<Model> {
319        let (pipeline, scheduler_config, add_model_config) = build_vision_pipeline(self).await?;
320        Ok(build_model_from_pipeline(pipeline, scheduler_config, add_model_config).await)
321    }
322}
323
324#[derive(Clone)]
325/// Configure a UQFF text model with the various parameters for loading, running, and other inference behaviors.
326/// This wraps and implements `DerefMut` for the VisionModelBuilder, so users should take care to not call UQFF-related methods.
327pub struct UqffVisionModelBuilder(VisionModelBuilder);
328
329impl UqffVisionModelBuilder {
330    /// A few defaults are applied here:
331    /// - Token source is from the cache (.cache/huggingface/token)
332    /// - Maximum number of sequences running is 32
333    /// - Automatic device mapping with model defaults according to `AutoDeviceMapParams`
334    pub fn new(model_id: impl ToString, uqff_file: Vec<PathBuf>) -> Self {
335        let mut inner = VisionModelBuilder::new(model_id);
336        inner.from_uqff = Some(uqff_file);
337        Self(inner)
338    }
339
340    pub async fn build(self) -> anyhow::Result<Model> {
341        self.0.build().await
342    }
343
344    /// This wraps the VisionModelBuilder, so users should take care to not call UQFF-related methods.
345    pub fn into_inner(self) -> VisionModelBuilder {
346        self.0
347    }
348}
349
350impl Deref for UqffVisionModelBuilder {
351    type Target = VisionModelBuilder;
352
353    fn deref(&self) -> &Self::Target {
354        &self.0
355    }
356}
357
358impl DerefMut for UqffVisionModelBuilder {
359    fn deref_mut(&mut self) -> &mut Self::Target {
360        &mut self.0
361    }
362}
363
364impl From<UqffVisionModelBuilder> for VisionModelBuilder {
365    fn from(value: UqffVisionModelBuilder) -> Self {
366        value.0
367    }
368}