use std::path::{Path, PathBuf};

use anyhow::{Context, Result};
use candle_core::{DType, Device};
use hf_hub::{
    api::sync::{ApiBuilder, ApiRepo},
    Cache, Repo, RepoType,
};
use serde::{Deserialize, Serialize};

use crate::device_map::{DeviceLayerMapMetadata, DeviceMapMetadata};
use crate::model_loader::{get_auto_device_map_params, get_model_dtype};
use crate::pipeline::{
    AutoDeviceMapParams, AutoEmbeddingLoader, AutoNormalLoader, AutoVisionLoader,
    DeviceMappedModelLoader, EmbeddingLoaderType, NormalLoaderType, TokenSource, VisionLoaderType,
};
use crate::utils::tokens::get_token;
use crate::{paged_attn_supported, IsqType, ModelSelected, TryIntoDType, GLOBAL_HF_CACHE};
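/// Quality-versus-speed preference that determines which quantization levels are tried first.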
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum TuneProfile {
    Quality,
    Balanced,
    Fast,
}
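/// Inputs to [`auto_tune`]: the model selection, token source, and tuning preferences.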
#[derive(Debug, Clone)]
pub struct AutoTuneRequest {
    pub model: ModelSelected,
    pub token_source: TokenSource,
    pub hf_revision: Option<String>,
    pub force_cpu: bool,
    pub profile: TuneProfile,
    pub requested_isq: Option<IsqType>,
}
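/// Coarse expected-quality bucket for a quantization level, from unquantized (`Baseline`) down to `Degraded`.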
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "kebab-case")]
pub enum QualityTier {
    Baseline,
    NearLossless,
    Good,
    Acceptable,
    Degraded,
}
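/// Whether a candidate fits fully on the accelerator(s), spills some layers to the CPU (`Hybrid`), or does not fit at all.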
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "kebab-case")]
pub enum FitStatus {
    Fits,
    Hybrid,
    TooLarge,
}
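/// One evaluated quantization level, with its estimated footprint, fit, and context headroom.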
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TuneCandidate {
    pub isq: Option<IsqType>,
    pub isq_name: String,
    pub estimated_size_bytes: u64,
    /// Estimated weight size as a percentage of total VRAM (0.0–100.0).
    pub vram_usage_percent: f32,
    pub max_context_tokens: usize,
    pub context_is_model_max: bool,
    pub quality: QualityTier,
    pub fit_status: FitStatus,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub device_layers_cli: Option<String>,
    pub recommended: bool,
}
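/// The complete tuning report: every candidate, the recommendation, and a ready-to-run command.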
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AutoTuneResult {
    pub model_id: String,
    pub profile: TuneProfile,
    pub backend: String,
    pub candidates: Vec<TuneCandidate>,
    pub recommended_isq: Option<IsqType>,
    pub device_layers: Option<Vec<DeviceLayerMapMetadata>>,
    pub device_layers_cli: Option<String>,
    pub paged_attn_mode: Option<String>,
    pub recommended_command: String,
    pub total_vram_bytes: u64,
    pub warnings: Vec<String>,
    pub notes: Vec<String>,
}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum TuneBackend {
    Cpu,
    Cuda,
    Metal,
}

#[derive(Clone, Copy, Debug)]
enum TuneKind {
    Normal,
    Vision,
    Embedding,
}
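/// Picks the backend by precedence: CUDA if any CUDA device is present, then Metal, then CPU.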
fn backend_from_devices(devices: &[Device]) -> TuneBackend {
    if devices.iter().any(|d| matches!(d, Device::Cuda(_))) {
        TuneBackend::Cuda
    } else if devices.iter().any(|d| matches!(d, Device::Metal(_))) {
        TuneBackend::Metal
    } else {
        TuneBackend::Cpu
    }
}

fn backend_name(backend: TuneBackend) -> String {
    match backend {
        TuneBackend::Cpu => "cpu".to_string(),
        TuneBackend::Cuda => "cuda".to_string(),
        TuneBackend::Metal => "metal".to_string(),
    }
}
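/// Selects the devices to tune against: all devices similar to the first available accelerator, or the CPU as a fallback.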
fn select_devices(force_cpu: bool) -> Result<Vec<Device>> {
    if force_cpu {
        return Ok(vec![Device::Cpu]);
    }

    #[cfg(all(feature = "cuda", target_family = "unix"))]
    {
        if let Ok(dev) = Device::new_cuda(0) {
            return Ok(crate::device_map::get_all_similar_devices(&dev)?);
        }
    }

    #[cfg(feature = "metal")]
    {
        if let Ok(dev) = Device::new_metal(0) {
            return Ok(crate::device_map::get_all_similar_devices(&dev)?);
        }
    }

    Ok(vec![Device::Cpu])
}
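/// Returns the per-model Hugging Face cache override, for the selection variants that carry one.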
fn hf_cache_path_from_model(model: &ModelSelected) -> Option<PathBuf> {
    match model {
        ModelSelected::Plain { hf_cache_path, .. }
        | ModelSelected::Lora { hf_cache_path, .. }
        | ModelSelected::XLora { hf_cache_path, .. }
        | ModelSelected::VisionPlain { hf_cache_path, .. }
        | ModelSelected::Embedding { hf_cache_path, .. }
        | ModelSelected::Run { hf_cache_path, .. } => hf_cache_path.clone(),
        _ => None,
    }
}
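/// Best-effort display identifier for the selected model, used in the report and the generated command.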
fn model_id_from_selected(model: &ModelSelected) -> String {
    match model {
        ModelSelected::Plain { model_id, .. }
        | ModelSelected::Lora {
            model_id: Some(model_id),
            ..
        }
        | ModelSelected::XLora {
            model_id: Some(model_id),
            ..
        }
        | ModelSelected::VisionPlain { model_id, .. }
        | ModelSelected::Embedding { model_id, .. }
        | ModelSelected::Run { model_id, .. } => model_id.clone(),
        ModelSelected::GGUF {
            quantized_model_id, ..
        }
        | ModelSelected::GGML {
            quantized_model_id, ..
        }
        | ModelSelected::LoraGGUF {
            quantized_model_id, ..
        }
        | ModelSelected::XLoraGGUF {
            quantized_model_id, ..
        }
        | ModelSelected::LoraGGML {
            quantized_model_id, ..
        }
        | ModelSelected::XLoraGGML {
            quantized_model_id, ..
        } => quantized_model_id.clone(),
        ModelSelected::DiffusionPlain { model_id, .. } => model_id.clone(),
        ModelSelected::Speech { model_id, .. } => model_id.clone(),
        ModelSelected::Toml { file } => file.clone(),
        ModelSelected::MultiModel { .. } => "multi-model".to_string(),
        _ => "unknown".to_string(),
    }
}
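/// Loads `config.json` for `model_id` from a local directory or the Hugging Face Hub,
/// and reports whether a `config_sentence_transformers.json` marker (an embedding-model
/// indicator) is present alongside it.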
fn load_config_artifacts(
    model_id: &str,
    token_source: &TokenSource,
    hf_revision: Option<String>,
    hf_cache_path: Option<PathBuf>,
) -> Result<(String, bool)> {
    if Path::new(model_id).exists() {
        let config_path = Path::new(model_id).join("config.json");
        let config = std::fs::read_to_string(&config_path)
            .with_context(|| format!("Failed to read config.json at {}", config_path.display()))?;
        let sentence_transformers = Path::new(model_id)
            .join("config_sentence_transformers.json")
            .exists();
        return Ok((config, sentence_transformers));
    }

    let cache = hf_cache_path
        .map(Cache::new)
        .unwrap_or_else(Cache::from_env);
    GLOBAL_HF_CACHE.get_or_init(|| cache.clone());

    let mut api = ApiBuilder::from_cache(cache)
        .with_progress(false)
        .with_token(get_token(token_source)?);
    if let Some(cache_dir) = crate::hf_hub_cache_dir() {
        api = api.with_cache_dir(cache_dir);
    }
    let api = api.build()?;
    let revision = hf_revision.unwrap_or_else(|| "main".to_string());
    let api = api.repo(Repo::with_revision(
        model_id.to_string(),
        RepoType::Model,
        revision,
    ));

    let config_path = api_get_file(&api, model_id, "config.json")?;
    let config = std::fs::read_to_string(&config_path).with_context(|| {
        format!(
            "Failed to read config.json from cache at {}",
            config_path.display()
        )
    })?;

    let sentence_transformers =
        api_get_file(&api, model_id, "config_sentence_transformers.json").is_ok();

    Ok((config, sentence_transformers))
}
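/// Resolves a repo file from a local model directory when one exists, otherwise through the Hub API.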
fn api_get_file(api: &ApiRepo, model_id: &str, file: &str) -> Result<PathBuf> {
    let model_id = Path::new(model_id);
    if model_id.exists() {
        let path = model_id.join(file);
        if path.exists() {
            Ok(path)
        } else {
            anyhow::bail!("File {file} not found at {}", model_id.display())
        }
    } else {
        Ok(api.get(file)?)
    }
}
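/// Infers whether the model is a vision, embedding, or normal text model from the config's single `architectures` entry.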
fn infer_kind(config: &str, sentence_transformers: bool) -> Result<TuneKind> {
    if sentence_transformers {
        return Ok(TuneKind::Embedding);
    }
    #[derive(Deserialize)]
    struct AutoConfig {
        architectures: Vec<String>,
    }
    let cfg: AutoConfig = serde_json::from_str(config)?;
    if cfg.architectures.len() != 1 {
        anyhow::bail!("Expected exactly one architecture in config");
    }
    let name = &cfg.architectures[0];
    if VisionLoaderType::from_causal_lm_name(name).is_ok() {
        return Ok(TuneKind::Vision);
    }
    if EmbeddingLoaderType::from_causal_lm_name(name).is_ok() {
        return Ok(TuneKind::Embedding);
    }
    let _ = NormalLoaderType::from_causal_lm_name(name)?;
    Ok(TuneKind::Normal)
}
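/// Every quantization level evaluated for the report, best quality first: the AFQ
/// family on Metal, the GGUF K-quant family elsewhere. `None` means no in-situ
/// quantization (unquantized weights).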
fn all_candidates(backend: TuneBackend) -> Vec<Option<IsqType>> {
    match backend {
        TuneBackend::Metal => vec![
            None,
            Some(IsqType::AFQ8),
            Some(IsqType::AFQ6),
            Some(IsqType::AFQ4),
            Some(IsqType::AFQ3),
            Some(IsqType::AFQ2),
        ],
        _ => vec![
            None,
            Some(IsqType::Q8_0),
            Some(IsqType::Q6K),
            Some(IsqType::Q5K),
            Some(IsqType::Q4K),
            Some(IsqType::Q3K),
            Some(IsqType::Q2K),
        ],
    }
}
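/// The preference order used for the recommendation when the caller did not pin a specific ISQ type.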
fn default_candidates(profile: TuneProfile, backend: TuneBackend) -> Vec<IsqType> {
    match backend {
        TuneBackend::Metal => match profile {
            TuneProfile::Quality => {
                vec![IsqType::AFQ8, IsqType::AFQ6, IsqType::AFQ4, IsqType::AFQ3]
            }
            TuneProfile::Balanced => vec![IsqType::AFQ6, IsqType::AFQ4, IsqType::AFQ3],
            TuneProfile::Fast => vec![IsqType::AFQ4, IsqType::AFQ3, IsqType::AFQ2],
        },
        _ => match profile {
            TuneProfile::Quality => vec![
                IsqType::Q8_0,
                IsqType::Q6K,
                IsqType::Q5K,
                IsqType::Q4K,
                IsqType::Q3K,
                IsqType::Q2K,
            ],
            TuneProfile::Balanced => vec![IsqType::Q6K, IsqType::Q5K, IsqType::Q4K, IsqType::Q3K],
            TuneProfile::Fast => vec![IsqType::Q4K, IsqType::Q3K, IsqType::Q2K],
        },
    }
}
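/// Buckets a quantization level by expected quality: 8-bit is near-lossless, 5- and 6-bit good, 4-bit acceptable, 2- and 3-bit degraded.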
fn quality_tier(isq: Option<IsqType>) -> QualityTier {
    match isq {
        None => QualityTier::Baseline,
        Some(t) => match t {
            IsqType::Q8_0 | IsqType::Q8_1 | IsqType::Q8K | IsqType::AFQ8 | IsqType::HQQ8 => {
                QualityTier::NearLossless
            }
            IsqType::Q6K | IsqType::AFQ6 => QualityTier::Good,
            IsqType::Q5_0 | IsqType::Q5_1 | IsqType::Q5K => QualityTier::Good,
            IsqType::Q4_0 | IsqType::Q4_1 | IsqType::Q4K | IsqType::AFQ4 | IsqType::HQQ4 => {
                QualityTier::Acceptable
            }
            IsqType::Q3K | IsqType::AFQ3 => QualityTier::Degraded,
            IsqType::Q2K | IsqType::AFQ2 => QualityTier::Degraded,
            _ => QualityTier::Acceptable,
        },
    }
}

fn isq_display_name(isq: Option<IsqType>) -> String {
    match isq {
        None => "None (FP16)".to_string(),
        Some(t) => format!("{t:?}"),
    }
}

#[allow(clippy::cast_possible_truncation)]
fn total_vram(devices: &[Device]) -> u64 {
    use crate::MemoryUsage;
    devices
        .iter()
        .filter(|d| !matches!(d, Device::Cpu))
        .filter_map(|d| MemoryUsage.get_total_memory(d).ok())
        .sum::<usize>() as u64
}

#[allow(clippy::cast_possible_truncation)]
fn available_vram(devices: &[Device]) -> u64 {
    use crate::MemoryUsage;
    devices
        .iter()
        .filter(|d| !matches!(d, Device::Cpu))
        .filter_map(|d| MemoryUsage.get_memory_available(d).ok())
        .sum::<usize>() as u64
}
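/// Estimates how many context tokens fit in the VRAM left over after the model weights.
///
/// The KV cache grows linearly with context length:
/// `kv_bytes_per_token = kv_cache_elements_per_token * dtype_size * num_layers`,
/// so the token budget is `(available_vram - model_size) / kv_bytes_per_token`,
/// capped at the model's native maximum sequence length.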
#[allow(clippy::cast_possible_truncation)]
fn calculate_max_context(
    loader: &dyn DeviceMappedModelLoader,
    config: &str,
    model_size_bytes: u64,
    available_vram_bytes: u64,
    dtype: DType,
) -> Result<(usize, bool)> {
    let model_cfg = loader.model_config(config)?;
    let native_max_seq_len = model_cfg.max_seq_len();

    if model_size_bytes >= available_vram_bytes {
        return Ok((0, false));
    }

    let remaining_bytes = available_vram_bytes - model_size_bytes;

    let kv_elems_per_token = model_cfg.kv_cache_elements_per_token();
    let num_layers = model_cfg.num_layers();

    let dtype_size = dtype.size_in_bytes();
    let kv_bytes_per_token = kv_elems_per_token * dtype_size * num_layers;

    if kv_bytes_per_token == 0 {
        return Ok((native_max_seq_len, true));
    }

    let calculated_max = remaining_bytes as usize / kv_bytes_per_token;

    let is_at_model_max = calculated_max >= native_max_seq_len;
    Ok((calculated_max.min(native_max_seq_len), is_at_model_max))
}
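/// Builds a device layer map for one candidate, returning it together with the total
/// weight size in bytes; the candidate's pack factor scales the per-layer size estimates.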
fn map_for_candidate(
    loader: &dyn DeviceMappedModelLoader,
    config: &str,
    dtype: DType,
    params: &AutoDeviceMapParams,
    devices: &[Device],
    isq: Option<IsqType>,
) -> Result<(DeviceMapMetadata, usize)> {
    let pack_factor = isq.map(|i| i.pack_factor(dtype)).unwrap_or(1);
    let layer_sizes = loader.layer_sizes_in_bytes(config, dtype, pack_factor, None)?;
    let non_mapped = loader.non_mapped_size_in_bytes(config, dtype, pack_factor, None)?;
    let total = layer_sizes.iter().sum::<usize>() + non_mapped;
    let map = crate::pipeline::get_device_layers_for_loader(
        loader,
        config,
        loader.num_layers(config)?,
        layer_sizes,
        non_mapped,
        total,
        devices,
        dtype,
        params,
        None,
    )?;
    Ok((map, total))
}
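/// Runs the full auto-tuning pass for `req`: selects devices, estimates the size and
/// context headroom of every supported quantization level, and recommends the best
/// level that fits under the requested profile, along with a ready-to-run command.
///
/// A minimal, illustrative call site (not a doctest: the `ModelSelected` value is
/// normally built from parsed CLI arguments, and the `TokenSource::CacheToken`
/// variant is assumed here):
///
/// ```ignore
/// let report = auto_tune(AutoTuneRequest {
///     model,                                // a previously constructed ModelSelected
///     token_source: TokenSource::CacheToken,
///     hf_revision: None,
///     force_cpu: false,
///     profile: TuneProfile::Balanced,
///     requested_isq: None,
/// })?;
/// println!("{}", report.recommended_command);
/// ```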
#[allow(
    clippy::cast_precision_loss,
    clippy::cast_possible_truncation,
    clippy::cast_sign_loss
)]
pub fn auto_tune(req: AutoTuneRequest) -> Result<AutoTuneResult> {
    let model_id = model_id_from_selected(&req.model);
    match &req.model {
        ModelSelected::GGUF { .. }
        | ModelSelected::GGML { .. }
        | ModelSelected::LoraGGUF { .. }
        | ModelSelected::XLoraGGUF { .. }
        | ModelSelected::LoraGGML { .. }
        | ModelSelected::XLoraGGML { .. } => {
            anyhow::bail!("Auto-tuning is not supported for pre-quantized GGUF/GGML models.");
        }
        ModelSelected::DiffusionPlain { .. } | ModelSelected::Speech { .. } => {
            anyhow::bail!("Auto-tuning is not supported for diffusion or speech models.");
        }
        _ => {}
    }

    let hf_cache_path = hf_cache_path_from_model(&req.model);
    let (config, sentence_transformers) = load_config_artifacts(
        &model_id,
        &req.token_source,
        req.hf_revision.clone(),
        hf_cache_path,
    )?;

    let kind = match &req.model {
        ModelSelected::VisionPlain { .. } => TuneKind::Vision,
        ModelSelected::Embedding { .. } => TuneKind::Embedding,
        _ => infer_kind(&config, sentence_transformers)?,
    };

    let mut params = get_auto_device_map_params(&req.model)?;
    if matches!(kind, TuneKind::Vision) {
        params = params.maybe_promote_to_vision();
    }

    let devices = select_devices(req.force_cpu)?;
    let backend = backend_from_devices(&devices);

    let dtype = {
        let model_dtype = get_model_dtype(&req.model)?;
        let refs = devices.iter().collect::<Vec<_>>();
        model_dtype.try_into_dtype(&refs)?
    };

    let loader_normal = AutoNormalLoader;
    let loader_vision = AutoVisionLoader;
    let loader_embedding = AutoEmbeddingLoader;
    let loader: &dyn DeviceMappedModelLoader = match kind {
        TuneKind::Normal => &loader_normal,
        TuneKind::Vision => &loader_vision,
        TuneKind::Embedding => &loader_embedding,
    };

    let preferred_candidates: Vec<Option<IsqType>> =
        req.requested_isq.map(|t| vec![Some(t)]).unwrap_or_else(|| {
            default_candidates(req.profile, backend)
                .into_iter()
                .map(Some)
                .collect()
        });

    let mut warnings = Vec::new();
    let mut notes = Vec::new();

    if matches!(kind, TuneKind::Embedding) {
        notes.push("Detected embedding model configuration.".to_string());
    }
    if matches!(kind, TuneKind::Vision) {
        notes.push("Detected vision model configuration.".to_string());
    }

    let total_vram_bytes = total_vram(&devices);
    let avail_vram_bytes = available_vram(&devices);

    let all_isq_candidates = all_candidates(backend);
    let mut tune_candidates: Vec<TuneCandidate> = Vec::new();
    let mut recommended_idx: Option<usize> = None;
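    // Evaluate every level the backend supports, not just the profile's preferences,
    // so the report is complete.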
    for isq in all_isq_candidates {
        let result = map_for_candidate(loader, &config, dtype, &params, &devices, isq);

        let (fit_status, estimated_size, device_layers_cli) = match &result {
            Ok((map, total_size)) => {
                let layers = map.device_layers();
                // An ordinal of usize::MAX marks a layer offloaded to the CPU.
                let is_hybrid = layers
                    .map(|l| l.iter().any(|d| d.ordinal == usize::MAX))
                    .unwrap_or(false);
                let status = if is_hybrid {
                    FitStatus::Hybrid
                } else {
                    FitStatus::Fits
                };
                (status, *total_size as u64, map.to_cli_spec())
            }
            Err(_) => {
                // Mapping failed (does not fit); still estimate the size for the report.
                let pack_factor = isq.map(|i| i.pack_factor(dtype)).unwrap_or(1);
                let layer_sizes = loader
                    .layer_sizes_in_bytes(&config, dtype, pack_factor, None)
                    .unwrap_or_default();
                let non_mapped = loader
                    .non_mapped_size_in_bytes(&config, dtype, pack_factor, None)
                    .unwrap_or(0);
                let est_size = (layer_sizes.iter().sum::<usize>() + non_mapped) as u64;
                (FitStatus::TooLarge, est_size, None)
            }
        };

        let vram_usage = if total_vram_bytes > 0 {
            // Express usage as a percentage to match the `vram_usage_percent` field.
            (estimated_size as f32) / (total_vram_bytes as f32) * 100.0
        } else {
            // No VRAM reported (CPU-only run): conservatively report full usage.
            100.0
        };

        let (context_room, context_is_model_max) =
            calculate_max_context(loader, &config, estimated_size, avail_vram_bytes, dtype)
                .unwrap_or((0, false));

        let candidate = TuneCandidate {
            isq,
            isq_name: isq_display_name(isq),
            estimated_size_bytes: estimated_size,
            vram_usage_percent: vram_usage,
            max_context_tokens: context_room,
            context_is_model_max,
            quality: quality_tier(isq),
            fit_status,
            device_layers_cli,
            recommended: false,
        };

        tune_candidates.push(candidate);
    }
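    // Prefer the profile's ISQ order among candidates that fit, fully or hybrid.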
    for pref in &preferred_candidates {
        if let Some(idx) = tune_candidates.iter().position(|c| {
            c.isq == *pref && matches!(c.fit_status, FitStatus::Fits | FitStatus::Hybrid)
        }) {
            tune_candidates[idx].recommended = true;
            recommended_idx = Some(idx);
            break;
        }
    }

    if recommended_idx.is_none() {
        // Fall back to the first candidate (best quality first) that fits at all.
        if let Some(idx) = tune_candidates
            .iter()
            .position(|c| matches!(c.fit_status, FitStatus::Fits | FitStatus::Hybrid))
        {
            tune_candidates[idx].recommended = true;
            recommended_idx = Some(idx);
        }
    }

    let (recommended_isq, device_layers, device_layers_cli, recommended_command) =
        if let Some(idx) = recommended_idx {
            let rec = &tune_candidates[idx];
            let isq_flag = rec
                .isq
                .map(|i| format!(" --isq {:?}", i).to_lowercase())
                .unwrap_or_default();
            let cmd = format!("mistralrs serve -m {model_id}{isq_flag}");
            (rec.isq, None, rec.device_layers_cli.clone(), cmd)
        } else {
            (None, None, None, format!("mistralrs serve -m {model_id}"))
        };

    let paged_attn_mode = if backend != TuneBackend::Cpu && paged_attn_supported() {
        Some("auto".to_string())
    } else {
        Some("off".to_string())
    };

    for c in &tune_candidates {
        if matches!(c.fit_status, FitStatus::TooLarge) && c.isq.is_some() {
            warnings.push(format!(
                "{} ({:.1} GB) exceeds available VRAM",
                c.isq_name,
                c.estimated_size_bytes as f64 / 1e9
            ));
        }
    }

    if recommended_idx.is_none() {
        anyhow::bail!(
            "No suitable quantization level fits on the available devices. Try a smaller model or enable CPU offload."
        );
    }

    Ok(AutoTuneResult {
        model_id,
        profile: req.profile,
        backend: backend_name(backend),
        candidates: tune_candidates,
        recommended_isq,
        device_layers,
        device_layers_cli,
        paged_attn_mode,
        recommended_command,
        total_vram_bytes,
        warnings,
        notes,
    })
}