chat_mistralrs/
builder.rs1use std::marker::PhantomData;
2use std::sync::Arc;
3
4use chat_core::error::{ChatError, ChatFailure};
5use chat_core::types::provider_meta::ProviderMeta;
6use mistralrs::{GgufModelBuilder, IsqType, ModelBuilder, MultimodalModelBuilder};
7
8use crate::client::MistralRsClient;
9
/// Typestate marker: the builder has not been given a model id yet.
///
/// Used only as the type parameter of [`MistralRsBuilder`]; it carries no data.
#[derive(Debug, Clone, Copy)]
pub struct WithoutModel;
/// Typestate marker: a model id has been supplied, so the builder can be built.
///
/// Used only as the type parameter of [`MistralRsBuilder`]; it carries no data.
#[derive(Debug, Clone, Copy)]
pub struct WithModel;
14
/// Which compute device the model should be loaded on.
///
/// Only [`DeviceChoice::Auto`] and [`DeviceChoice::Cpu`] are currently accepted by
/// `MistralRsBuilder::with_device`; the remaining variants make that method panic
/// because they are not yet wired through to the mistral.rs loader.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum DeviceChoice {
    /// Do not force anything: the loader is left to pick the device itself.
    #[default]
    Auto,
    /// Force the loader onto the CPU (`with_force_cpu()` is called on it).
    Cpu,
    /// Request CUDA. Currently rejected by `with_device`.
    Cuda,
    /// Request a specific CUDA device — presumably the device index; TODO confirm
    /// once this is wired through. Currently rejected by `with_device`.
    CudaOrdinal(usize),
    /// Request Metal. Currently rejected by `with_device`.
    Metal,
}
30
31pub struct MistralRsBuilder<M = WithoutModel> {
38 model_id: Option<String>,
39 gguf_file: Option<String>,
40 tok_model_id: Option<String>,
41 device: DeviceChoice,
42 multimodal: bool,
43 isq: Option<IsqType>,
44 logging: bool,
45 description: Option<String>,
46 _m: PhantomData<M>,
47}
48
49impl Default for MistralRsBuilder<WithoutModel> {
50 fn default() -> Self {
51 Self::new()
52 }
53}
54
55impl MistralRsBuilder<WithoutModel> {
56 pub fn new() -> Self {
57 Self {
58 model_id: None,
59 gguf_file: None,
60 tok_model_id: None,
61 device: DeviceChoice::Auto,
62 multimodal: false,
63 isq: None,
64 logging: false,
65 description: None,
66 _m: PhantomData,
67 }
68 }
69
70 pub fn with_model(self, id: impl Into<String>) -> MistralRsBuilder<WithModel> {
73 MistralRsBuilder {
74 model_id: Some(id.into()),
75 gguf_file: self.gguf_file,
76 tok_model_id: self.tok_model_id,
77 device: self.device,
78 multimodal: self.multimodal,
79 isq: self.isq,
80 logging: self.logging,
81 description: self.description,
82 _m: PhantomData,
83 }
84 }
85}
86
87impl<M> MistralRsBuilder<M> {
88 pub fn with_gguf_file(mut self, file: impl Into<String>) -> Self {
92 self.gguf_file = Some(file.into());
93 self
94 }
95
96 pub fn with_tok_model_id(mut self, id: impl Into<String>) -> Self {
100 self.tok_model_id = Some(id.into());
101 self
102 }
103
104 pub fn with_device(mut self, device: DeviceChoice) -> Self {
111 match device {
112 DeviceChoice::Auto | DeviceChoice::Cpu => {}
113 DeviceChoice::Cuda | DeviceChoice::CudaOrdinal(_) | DeviceChoice::Metal => {
114 panic!(
115 "DeviceChoice::{device:?} is not yet wired through to the mistral.rs \
116 loader — use DeviceChoice::Auto (and enable the `metal` or `cuda` \
117 feature at compile time) or DeviceChoice::Cpu"
118 );
119 }
120 }
121 self.device = device;
122 self
123 }
124
125 pub fn with_multimodal(mut self) -> Self {
129 self.multimodal = true;
130 self
131 }
132
133 pub fn with_isq(mut self, isq: IsqType) -> Self {
137 self.isq = Some(isq);
138 self
139 }
140
141 pub fn with_logging(mut self) -> Self {
145 self.logging = true;
146 self
147 }
148
149 pub fn with_description(mut self, d: impl Into<String>) -> Self {
150 self.description = Some(d.into());
151 self
152 }
153}
154
155impl MistralRsBuilder<WithModel> {
156 pub async fn build(self) -> Result<MistralRsClient, ChatFailure> {
162 let model_id = self.model_id.expect("with_model() sets model_id");
163 let force_cpu = matches!(self.device, DeviceChoice::Cpu);
164
165 if self.multimodal && (self.gguf_file.is_some() || self.tok_model_id.is_some()) {
166 return Err(build_failure(
167 "builder",
168 &model_id,
169 anyhow::anyhow!(
170 "with_multimodal() is incompatible with with_gguf_file() / \
171 with_tok_model_id(): the multimodal loader does not consume \
172 GGUF files or a separate tokenizer source. Pick one path."
173 ),
174 ));
175 }
176
177 let model = if self.multimodal {
178 let mut b = MultimodalModelBuilder::new(model_id.clone());
179 if let Some(isq) = self.isq {
180 b = b.with_isq(isq);
181 }
182 if force_cpu {
183 b = b.with_force_cpu();
184 }
185 if self.logging {
186 b = b.with_logging();
187 }
188 b.build()
189 .await
190 .map_err(|e| build_failure("multimodal loader", &model_id, e))?
191 } else if let Some(gguf_file) = self.gguf_file.clone() {
192 let mut b = GgufModelBuilder::new(model_id.clone(), vec![gguf_file]);
193 if let Some(tok) = self.tok_model_id.clone() {
194 b = b.with_tok_model_id(tok);
195 }
196 if force_cpu {
197 b = b.with_force_cpu();
198 }
199 if self.logging {
200 b = b.with_logging();
201 }
202 b.build()
203 .await
204 .map_err(|e| build_failure("GGUF loader", &model_id, e))?
205 } else {
206 let mut b = ModelBuilder::new(model_id.clone());
207 if let Some(isq) = self.isq {
208 b = b.with_isq(isq);
209 }
210 if force_cpu {
211 b = b.with_force_cpu();
212 }
213 if self.logging {
214 b = b.with_logging();
215 }
216 b.build()
217 .await
218 .map_err(|e| build_failure("auto-detect loader", &model_id, e))?
219 };
220
221 let meta = Arc::new(ProviderMeta {
222 description: self.description,
223 ..Default::default()
224 });
225
226 Ok(MistralRsClient {
227 model: Arc::new(model),
228 model_id,
229 meta,
230 })
231 }
232}
233
234fn build_failure(loader: &str, model_id: &str, err: anyhow::Error) -> ChatFailure {
235 ChatFailure::from_err(ChatError::Provider(format!(
236 "mistral.rs {loader} failed to load {model_id}: {err}"
237 )))
238}