1use candle_core::Device;
2use mistralrs_core::*;
3use mistralrs_core::{SearchCallback, Tool, ToolCallback};
4use std::collections::HashMap;
5use std::{
6 ops::{Deref, DerefMut},
7 path::PathBuf,
8 sync::Arc,
9};
10
11use crate::model_builder_trait::{build_model_from_pipeline, build_vision_pipeline};
12use crate::Model;
13
14#[derive(Clone)]
15pub struct VisionModelBuilder {
17 pub(crate) model_id: String,
19 pub(crate) token_source: TokenSource,
20 pub(crate) hf_revision: Option<String>,
21 pub(crate) write_uqff: Option<PathBuf>,
22 pub(crate) from_uqff: Option<Vec<PathBuf>>,
23 pub(crate) calibration_file: Option<PathBuf>,
24 pub(crate) imatrix: Option<PathBuf>,
25 pub(crate) chat_template: Option<String>,
26 pub(crate) jinja_explicit: Option<String>,
27 pub(crate) tokenizer_json: Option<String>,
28 pub(crate) device_mapping: Option<DeviceMapSetting>,
29 pub(crate) max_edge: Option<u32>,
30 pub(crate) hf_cache_path: Option<PathBuf>,
31 pub(crate) search_embedding_model: Option<SearchEmbeddingModel>,
32 pub(crate) search_callback: Option<Arc<SearchCallback>>,
33 pub(crate) tool_callbacks: HashMap<String, Arc<ToolCallback>>,
34 pub(crate) tool_callbacks_with_tools: HashMap<String, ToolCallbackWithTool>,
35 pub(crate) device: Option<Device>,
36 pub(crate) matformer_config_path: Option<PathBuf>,
37 pub(crate) matformer_slice_name: Option<String>,
38
39 pub(crate) topology: Option<Topology>,
41 pub(crate) topology_path: Option<String>,
42 pub(crate) loader_type: Option<VisionLoaderType>,
43 pub(crate) dtype: ModelDType,
44 pub(crate) force_cpu: bool,
45 pub(crate) isq: Option<IsqType>,
46 pub(crate) throughput_logging: bool,
47
48 pub(crate) paged_attn_cfg: Option<PagedAttentionConfig>,
50 pub(crate) max_num_seqs: usize,
51 pub(crate) with_logging: bool,
52 pub(crate) prefix_cache_n: Option<usize>,
53}
54
55impl VisionModelBuilder {
56 pub fn new(model_id: impl ToString) -> Self {
62 Self {
63 model_id: model_id.to_string(),
64 topology: None,
65 topology_path: None,
66 write_uqff: None,
67 from_uqff: None,
68 chat_template: None,
69 tokenizer_json: None,
70 max_edge: None,
71 loader_type: None,
72 dtype: ModelDType::Auto,
73 force_cpu: false,
74 token_source: TokenSource::CacheToken,
75 hf_revision: None,
76 isq: None,
77 max_num_seqs: 32,
78 with_logging: false,
79 device_mapping: None,
80 calibration_file: None,
81 imatrix: None,
82 jinja_explicit: None,
83 throughput_logging: false,
84 paged_attn_cfg: None,
85 hf_cache_path: None,
86 search_embedding_model: None,
87 search_callback: None,
88 tool_callbacks: HashMap::new(),
89 tool_callbacks_with_tools: HashMap::new(),
90 device: None,
91 matformer_config_path: None,
92 matformer_slice_name: None,
93 prefix_cache_n: None,
94 }
95 }
96
97 pub fn with_search(mut self, search_embedding_model: SearchEmbeddingModel) -> Self {
99 self.search_embedding_model = Some(search_embedding_model);
100 self
101 }
102
103 pub fn with_search_callback(mut self, callback: Arc<SearchCallback>) -> Self {
105 self.search_callback = Some(callback);
106 self
107 }
108
109 pub fn with_tool_callback(
110 mut self,
111 name: impl Into<String>,
112 callback: Arc<ToolCallback>,
113 ) -> Self {
114 self.tool_callbacks.insert(name.into(), callback);
115 self
116 }
117
118 pub fn with_tool_callback_and_tool(
121 mut self,
122 name: impl Into<String>,
123 callback: Arc<ToolCallback>,
124 tool: Tool,
125 ) -> Self {
126 let name = name.into();
127 self.tool_callbacks_with_tools
128 .insert(name, ToolCallbackWithTool { callback, tool });
129 self
130 }
131
132 pub fn with_throughput_logging(mut self) -> Self {
134 self.throughput_logging = true;
135 self
136 }
137
138 pub fn with_jinja_explicit(mut self, jinja_explicit: String) -> Self {
140 self.jinja_explicit = Some(jinja_explicit);
141 self
142 }
143
144 pub fn with_topology(mut self, topology: Topology) -> Self {
146 self.topology = Some(topology);
147 self
148 }
149
150 pub fn with_topology_from_path<P: AsRef<std::path::Path>>(
153 mut self,
154 path: P,
155 ) -> anyhow::Result<Self> {
156 let path_str = path.as_ref().to_string_lossy().to_string();
157 self.topology = Some(Topology::from_path(&path)?);
158 self.topology_path = Some(path_str);
159 Ok(self)
160 }
161
162 pub fn with_chat_template(mut self, chat_template: impl ToString) -> Self {
164 self.chat_template = Some(chat_template.to_string());
165 self
166 }
167
168 pub fn with_tokenizer_json(mut self, tokenizer_json: impl ToString) -> Self {
170 self.tokenizer_json = Some(tokenizer_json.to_string());
171 self
172 }
173
174 pub fn with_loader_type(mut self, loader_type: VisionLoaderType) -> Self {
177 self.loader_type = Some(loader_type);
178 self
179 }
180
181 pub fn with_dtype(mut self, dtype: ModelDType) -> Self {
183 self.dtype = dtype;
184 self
185 }
186
187 pub fn with_force_cpu(mut self) -> Self {
189 self.force_cpu = true;
190 self
191 }
192
193 pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
195 self.token_source = token_source;
196 self
197 }
198
199 pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
201 self.hf_revision = Some(revision.to_string());
202 self
203 }
204
205 pub fn with_isq(mut self, isq: IsqType) -> Self {
207 self.isq = Some(isq);
208 self
209 }
210
211 pub fn with_calibration_file(mut self, path: PathBuf) -> Self {
213 self.calibration_file = Some(path);
214 self
215 }
216
217 pub fn with_paged_attn(
222 mut self,
223 paged_attn_cfg: impl FnOnce() -> anyhow::Result<PagedAttentionConfig>,
224 ) -> anyhow::Result<Self> {
225 if paged_attn_supported() {
226 self.paged_attn_cfg = Some(paged_attn_cfg()?);
227 } else {
228 self.paged_attn_cfg = None;
229 }
230 Ok(self)
231 }
232
233 pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
235 self.max_num_seqs = max_num_seqs;
236 self
237 }
238
239 pub fn with_prefix_cache_n(mut self, n_seqs: Option<usize>) -> Self {
241 self.prefix_cache_n = n_seqs;
242 self
243 }
244
245 pub fn with_logging(mut self) -> Self {
247 self.with_logging = true;
248 self
249 }
250
251 pub fn with_device_mapping(mut self, device_mapping: DeviceMapSetting) -> Self {
253 self.device_mapping = Some(device_mapping);
254 self
255 }
256
257 #[deprecated(
258 note = "Use `UqffTextModelBuilder` to load a UQFF model instead of the generic `from_uqff`"
259 )]
260 pub fn from_uqff(mut self, path: Vec<PathBuf>) -> Self {
268 self.from_uqff = Some(path);
269 self
270 }
271
272 pub fn from_max_edge(mut self, max_edge: u32) -> Self {
275 self.max_edge = Some(max_edge);
276 self
277 }
278
279 pub fn write_uqff(mut self, path: PathBuf) -> Self {
290 self.write_uqff = Some(path);
291 self
292 }
293
294 pub fn from_hf_cache_pathf(mut self, hf_cache_path: PathBuf) -> Self {
296 self.hf_cache_path = Some(hf_cache_path);
297 self
298 }
299
300 pub fn with_device(mut self, device: Device) -> Self {
302 self.device = Some(device);
303 self
304 }
305
306 pub fn with_matformer_config_path(mut self, path: PathBuf) -> Self {
308 self.matformer_config_path = Some(path);
309 self
310 }
311
312 pub fn with_matformer_slice_name(mut self, name: String) -> Self {
314 self.matformer_slice_name = Some(name);
315 self
316 }
317
318 pub async fn build(self) -> anyhow::Result<Model> {
319 let (pipeline, scheduler_config, add_model_config) = build_vision_pipeline(self).await?;
320 Ok(build_model_from_pipeline(pipeline, scheduler_config, add_model_config).await)
321 }
322}
323
/// Builder for loading a vision model from UQFF file(s).
///
/// Wraps a [`VisionModelBuilder`] whose `from_uqff` field is filled in by
/// [`UqffVisionModelBuilder::new`]. The inner builder is reachable through
/// `Deref`/`DerefMut`, `into_inner`, or `From`.
#[derive(Clone)]
pub struct UqffVisionModelBuilder(VisionModelBuilder);
328
329impl UqffVisionModelBuilder {
330 pub fn new(model_id: impl ToString, uqff_file: Vec<PathBuf>) -> Self {
335 let mut inner = VisionModelBuilder::new(model_id);
336 inner.from_uqff = Some(uqff_file);
337 Self(inner)
338 }
339
340 pub async fn build(self) -> anyhow::Result<Model> {
341 self.0.build().await
342 }
343
344 pub fn into_inner(self) -> VisionModelBuilder {
346 self.0
347 }
348}
349
350impl Deref for UqffVisionModelBuilder {
351 type Target = VisionModelBuilder;
352
353 fn deref(&self) -> &Self::Target {
354 &self.0
355 }
356}
357
358impl DerefMut for UqffVisionModelBuilder {
359 fn deref_mut(&mut self) -> &mut Self::Target {
360 &mut self.0
361 }
362}
363
364impl From<UqffVisionModelBuilder> for VisionModelBuilder {
365 fn from(value: UqffVisionModelBuilder) -> Self {
366 value.0
367 }
368}