Skip to main content

chat_mistralrs/
builder.rs

1use std::marker::PhantomData;
2use std::sync::Arc;
3
4use chat_core::error::{ChatError, ChatFailure};
5use chat_core::types::provider_meta::ProviderMeta;
6use mistralrs::{GgufModelBuilder, IsqType, ModelBuilder, MultimodalModelBuilder};
7
8use crate::client::MistralRsClient;
9
/// Typestate marker — no model id set yet, `build()` is not callable.
///
/// Zero-sized; it only appears as the `M` parameter of `MistralRsBuilder`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct WithoutModel;

/// Typestate marker — model id is set, `build()` is callable.
///
/// Zero-sized; it only appears as the `M` parameter of `MistralRsBuilder`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct WithModel;
14
/// Device on which to run inference. `Auto` picks the best available
/// backend that was compiled in: Metal on macOS with the `metal` feature,
/// CUDA on Linux with the `cuda` feature, else CPU.
///
/// Derives `PartialEq`/`Eq`/`Hash` so callers can compare choices or use
/// them as map keys; it is `Copy`, so pass it by value.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum DeviceChoice {
    /// Let the compiled-in backend decide (the default).
    #[default]
    Auto,
    /// Force CPU inference.
    Cpu,
    /// First CUDA device.
    ///
    /// Not yet accepted by `MistralRsBuilder::with_device`, which panics on it.
    Cuda,
    /// Specific CUDA device by ordinal.
    ///
    /// Not yet accepted by `MistralRsBuilder::with_device`, which panics on it.
    CudaOrdinal(usize),
    /// Apple Metal (macOS only).
    ///
    /// Not yet accepted by `MistralRsBuilder::with_device`, which panics on it.
    Metal,
}
30
31/// Builder for [`MistralRsClient`].
32///
33/// `with_model` sets the Hugging Face repo id (e.g. `Qwen/Qwen2.5-3B-Instruct`).
34/// Calling `with_gguf_file` additionally selects a specific GGUF filename
35/// inside that repo, which switches the loader to the GGUF path; without
36/// it, the auto-detect loader runs against `config.json`.
37pub struct MistralRsBuilder<M = WithoutModel> {
38    model_id: Option<String>,
39    gguf_file: Option<String>,
40    tok_model_id: Option<String>,
41    device: DeviceChoice,
42    multimodal: bool,
43    isq: Option<IsqType>,
44    logging: bool,
45    description: Option<String>,
46    _m: PhantomData<M>,
47}
48
49impl Default for MistralRsBuilder<WithoutModel> {
50    fn default() -> Self {
51        Self::new()
52    }
53}
54
55impl MistralRsBuilder<WithoutModel> {
56    pub fn new() -> Self {
57        Self {
58            model_id: None,
59            gguf_file: None,
60            tok_model_id: None,
61            device: DeviceChoice::Auto,
62            multimodal: false,
63            isq: None,
64            logging: false,
65            description: None,
66            _m: PhantomData,
67        }
68    }
69
70    /// Set the Hugging Face repo id. Transitions the builder so `.build()`
71    /// becomes callable.
72    pub fn with_model(self, id: impl Into<String>) -> MistralRsBuilder<WithModel> {
73        MistralRsBuilder {
74            model_id: Some(id.into()),
75            gguf_file: self.gguf_file,
76            tok_model_id: self.tok_model_id,
77            device: self.device,
78            multimodal: self.multimodal,
79            isq: self.isq,
80            logging: self.logging,
81            description: self.description,
82            _m: PhantomData,
83        }
84    }
85}
86
87impl<M> MistralRsBuilder<M> {
88    /// Select a specific `.gguf` file inside the repo and use the GGUF
89    /// loader. Required for GGUF-only repos; omit for safetensors repos
90    /// where the auto-detect loader handles file discovery.
91    pub fn with_gguf_file(mut self, file: impl Into<String>) -> Self {
92        self.gguf_file = Some(file.into());
93        self
94    }
95
96    /// Override the tokenizer source. Useful for GGUF repos that don't
97    /// ship `tokenizer.json` — point this at the original safetensors
98    /// repo (e.g. `Qwen/Qwen2.5-3B-Instruct`).
99    pub fn with_tok_model_id(mut self, id: impl Into<String>) -> Self {
100        self.tok_model_id = Some(id.into());
101        self
102    }
103
104    /// Only [`DeviceChoice::Auto`] and [`DeviceChoice::Cpu`] are wired
105    /// through today. Explicit GPU/Metal selection is not yet implemented —
106    /// passing `Cuda`, `CudaOrdinal`, or `Metal` panics rather than silently
107    /// falling back to auto-detect. For now, use `Auto` (which honours the
108    /// compiled-in backend feature) and select GPU at compile time via the
109    /// `metal` / `cuda` features.
110    pub fn with_device(mut self, device: DeviceChoice) -> Self {
111        match device {
112            DeviceChoice::Auto | DeviceChoice::Cpu => {}
113            DeviceChoice::Cuda | DeviceChoice::CudaOrdinal(_) | DeviceChoice::Metal => {
114                panic!(
115                    "DeviceChoice::{device:?} is not yet wired through to the mistral.rs \
116                     loader — use DeviceChoice::Auto (and enable the `metal` or `cuda` \
117                     feature at compile time) or DeviceChoice::Cpu"
118                );
119            }
120        }
121        self.device = device;
122        self
123    }
124
125    /// Switch the loader to mistral.rs's multimodal path. Required for
126    /// vision / audio models (e.g. Voxtral, Gemma 3n, Phi-4-MM,
127    /// Qwen2.5-VL). Without this, the auto-detect / GGUF loaders run.
128    pub fn with_multimodal(mut self) -> Self {
129        self.multimodal = true;
130        self
131    }
132
133    /// Apply mistral.rs in-situ quantization (ISQ). Useful for fitting
134    /// larger multimodal models into a MacBook's memory, e.g.
135    /// `IsqType::Q4K` for ~4-bit weights.
136    pub fn with_isq(mut self, isq: IsqType) -> Self {
137        self.isq = Some(isq);
138        self
139    }
140
141    /// Enable mistral.rs's built-in loader/download progress logs. Useful
142    /// for first-time runs where weight download can take many minutes —
143    /// without this the process is silent until the model is fully loaded.
144    pub fn with_logging(mut self) -> Self {
145        self.logging = true;
146        self
147    }
148
149    pub fn with_description(mut self, d: impl Into<String>) -> Self {
150        self.description = Some(d.into());
151        self
152    }
153}
154
155impl MistralRsBuilder<WithModel> {
156    /// Load weights and return a ready-to-use client.
157    ///
158    /// This is async and fallible because it actually downloads (on first
159    /// run) and loads the model into memory — multi-second on a warm cache,
160    /// multi-minute on a cold one.
161    pub async fn build(self) -> Result<MistralRsClient, ChatFailure> {
162        let model_id = self.model_id.expect("with_model() sets model_id");
163        let force_cpu = matches!(self.device, DeviceChoice::Cpu);
164
165        if self.multimodal && (self.gguf_file.is_some() || self.tok_model_id.is_some()) {
166            return Err(build_failure(
167                "builder",
168                &model_id,
169                anyhow::anyhow!(
170                    "with_multimodal() is incompatible with with_gguf_file() / \
171                     with_tok_model_id(): the multimodal loader does not consume \
172                     GGUF files or a separate tokenizer source. Pick one path."
173                ),
174            ));
175        }
176
177        let model = if self.multimodal {
178            let mut b = MultimodalModelBuilder::new(model_id.clone());
179            if let Some(isq) = self.isq {
180                b = b.with_isq(isq);
181            }
182            if force_cpu {
183                b = b.with_force_cpu();
184            }
185            if self.logging {
186                b = b.with_logging();
187            }
188            b.build()
189                .await
190                .map_err(|e| build_failure("multimodal loader", &model_id, e))?
191        } else if let Some(gguf_file) = self.gguf_file.clone() {
192            let mut b = GgufModelBuilder::new(model_id.clone(), vec![gguf_file]);
193            if let Some(tok) = self.tok_model_id.clone() {
194                b = b.with_tok_model_id(tok);
195            }
196            if force_cpu {
197                b = b.with_force_cpu();
198            }
199            if self.logging {
200                b = b.with_logging();
201            }
202            b.build()
203                .await
204                .map_err(|e| build_failure("GGUF loader", &model_id, e))?
205        } else {
206            let mut b = ModelBuilder::new(model_id.clone());
207            if let Some(isq) = self.isq {
208                b = b.with_isq(isq);
209            }
210            if force_cpu {
211                b = b.with_force_cpu();
212            }
213            if self.logging {
214                b = b.with_logging();
215            }
216            b.build()
217                .await
218                .map_err(|e| build_failure("auto-detect loader", &model_id, e))?
219        };
220
221        let meta = Arc::new(ProviderMeta {
222            description: self.description,
223            ..Default::default()
224        });
225
226        Ok(MistralRsClient {
227            model: Arc::new(model),
228            model_id,
229            meta,
230        })
231    }
232}
233
234fn build_failure(loader: &str, model_id: &str, err: anyhow::Error) -> ChatFailure {
235    ChatFailure::from_err(ChatError::Provider(format!(
236        "mistral.rs {loader} failed to load {model_id}: {err}"
237    )))
238}