1use std::path::{Path, PathBuf};
2
3use anyhow::{Context, Result};
4use hanzo_ml::{DType, Device};
5use hf_hub::{
6 api::sync::{ApiBuilder, ApiRepo},
7 Cache, Repo, RepoType,
8};
9use serde::{Deserialize, Serialize};
10
11use crate::device_map::{DeviceLayerMapMetadata, DeviceMapMetadata};
12use crate::model_loader::{get_auto_device_map_params, get_model_dtype};
13use crate::pipeline::{
14 AutoDeviceMapParams, AutoEmbeddingLoader, AutoMultimodalLoader, AutoNormalLoader,
15 DeviceMappedModelLoader, EmbeddingLoaderType, MultimodalLoaderType, NormalLoaderType,
16 TokenSource,
17};
18use crate::utils::tokens::get_token;
19use crate::{paged_attn_supported, IsqType, ModelSelected, TryIntoDType, GLOBAL_HF_CACHE};
20
21#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
22#[serde(rename_all = "kebab-case")]
23pub enum TuneProfile {
24 Quality,
25 Balanced,
26 Fast,
27}
28
29#[derive(Debug, Clone)]
30pub struct AutoTuneRequest {
31 pub model: ModelSelected,
32 pub token_source: TokenSource,
33 pub hf_revision: Option<String>,
34 pub force_cpu: bool,
35 pub profile: TuneProfile,
36 pub requested_isq: Option<IsqType>,
37}
38
39#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
41#[serde(rename_all = "kebab-case")]
42pub enum QualityTier {
43 Baseline,
45 NearLossless,
47 Good,
49 Acceptable,
51 Degraded,
53}
54
55#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
57#[serde(rename_all = "kebab-case")]
58pub enum FitStatus {
59 Fits,
61 Hybrid,
63 TooLarge,
65}
66
67#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct TuneCandidate {
70 pub isq: Option<IsqType>,
72 pub isq_name: String,
74 pub estimated_size_bytes: u64,
76 pub vram_usage_percent: f32,
78 pub max_context_tokens: usize,
80 pub context_is_model_max: bool,
82 pub quality: QualityTier,
84 pub fit_status: FitStatus,
86 #[serde(skip_serializing_if = "Option::is_none")]
88 pub device_layers_cli: Option<String>,
89 pub recommended: bool,
91}
92
93#[derive(Debug, Clone, Serialize, Deserialize)]
94pub struct AutoTuneResult {
95 pub model_id: String,
96 pub profile: TuneProfile,
97 pub backend: String,
98 pub candidates: Vec<TuneCandidate>,
100 pub recommended_isq: Option<IsqType>,
102 pub device_layers: Option<Vec<DeviceLayerMapMetadata>>,
104 pub device_layers_cli: Option<String>,
105 pub paged_attn_mode: Option<String>,
106 pub recommended_command: String,
108 pub total_vram_bytes: u64,
110 pub warnings: Vec<String>,
111 pub notes: Vec<String>,
112}
113
/// Compute backend the tuner targets, derived from the selected devices.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
enum TuneBackend {
    /// No CUDA or Metal device is in use.
    Cpu,
    /// At least one CUDA device is in use.
    Cuda,
    /// At least one Metal device is in use (and no CUDA device).
    Metal,
}
120
/// Family of model loader used to size the model.
#[derive(Debug, Clone, Copy)]
enum TuneKind {
    /// Sized with the normal (text) loader.
    Normal,
    /// Sized with the multimodal loader.
    Multimodal,
    /// Sized with the embedding loader.
    Embedding,
}
127
128fn backend_from_devices(devices: &[Device]) -> TuneBackend {
129 if devices.iter().any(|d| matches!(d, Device::Cuda(_))) {
130 TuneBackend::Cuda
131 } else if devices.iter().any(|d| matches!(d, Device::Metal(_))) {
132 TuneBackend::Metal
133 } else {
134 TuneBackend::Cpu
135 }
136}
137
138fn backend_name(backend: TuneBackend) -> String {
139 match backend {
140 TuneBackend::Cpu => "cpu".to_string(),
141 TuneBackend::Cuda => "cuda".to_string(),
142 TuneBackend::Metal => "metal".to_string(),
143 }
144}
145
146fn select_devices(force_cpu: bool) -> Result<Vec<Device>> {
147 if force_cpu {
148 return Ok(vec![Device::Cpu]);
149 }
150
151 #[cfg(all(feature = "cuda", target_family = "unix"))]
152 {
153 if let Ok(dev) = Device::new_cuda(0) {
154 return Ok(crate::device_map::get_all_similar_devices(&dev)?);
155 }
156 }
157
158 #[cfg(feature = "metal")]
159 {
160 if let Ok(dev) = Device::new_metal(0) {
161 return Ok(crate::device_map::get_all_similar_devices(&dev)?);
162 }
163 }
164
165 Ok(vec![Device::Cpu])
166}
167
168fn hf_cache_path_from_model(model: &ModelSelected) -> Option<PathBuf> {
169 match model {
170 ModelSelected::Plain { hf_cache_path, .. }
171 | ModelSelected::Lora { hf_cache_path, .. }
172 | ModelSelected::XLora { hf_cache_path, .. }
173 | ModelSelected::MultimodalPlain { hf_cache_path, .. }
174 | ModelSelected::Embedding { hf_cache_path, .. }
175 | ModelSelected::Run { hf_cache_path, .. } => hf_cache_path.clone(),
176 _ => None,
177 }
178}
179
180fn model_id_from_selected(model: &ModelSelected) -> String {
181 match model {
182 ModelSelected::Plain { model_id, .. }
183 | ModelSelected::Lora {
184 model_id: Some(model_id),
185 ..
186 }
187 | ModelSelected::XLora {
188 model_id: Some(model_id),
189 ..
190 }
191 | ModelSelected::MultimodalPlain { model_id, .. }
192 | ModelSelected::Embedding { model_id, .. }
193 | ModelSelected::Run { model_id, .. } => model_id.clone(),
194 ModelSelected::GGUF {
195 quantized_model_id, ..
196 }
197 | ModelSelected::GGML {
198 quantized_model_id, ..
199 }
200 | ModelSelected::LoraGGUF {
201 quantized_model_id, ..
202 }
203 | ModelSelected::XLoraGGUF {
204 quantized_model_id, ..
205 }
206 | ModelSelected::LoraGGML {
207 quantized_model_id, ..
208 }
209 | ModelSelected::XLoraGGML {
210 quantized_model_id, ..
211 } => quantized_model_id.clone(),
212 ModelSelected::DiffusionPlain { model_id, .. } => model_id.clone(),
213 ModelSelected::Speech { model_id, .. } => model_id.clone(),
214 ModelSelected::Toml { file } => file.clone(),
215 ModelSelected::MultiModel { .. } => "multi-model".to_string(),
216 _ => "unknown".to_string(),
217 }
218}
219
220fn load_config_artifacts(
221 model_id: &str,
222 token_source: &TokenSource,
223 hf_revision: Option<String>,
224 hf_cache_path: Option<PathBuf>,
225) -> Result<(String, bool)> {
226 if Path::new(model_id).exists() {
227 let config_path = Path::new(model_id).join("config.json");
228 let config = std::fs::read_to_string(&config_path)
229 .with_context(|| format!("Failed to read config.json at {}", config_path.display()))?;
230 let sentence_transformers = Path::new(model_id)
231 .join("config_sentence_transformers.json")
232 .exists();
233 return Ok((config, sentence_transformers));
234 }
235
236 let cache = hf_cache_path
237 .map(Cache::new)
238 .unwrap_or_else(Cache::from_env);
239 GLOBAL_HF_CACHE.get_or_init(|| cache.clone());
240
241 let mut api = ApiBuilder::from_cache(cache)
242 .with_progress(false)
243 .with_token(get_token(token_source)?);
244 if let Some(cache_dir) = crate::hf_hub_cache_dir() {
245 api = api.with_cache_dir(cache_dir);
246 }
247 let api = api.build()?;
248 let revision = hf_revision.unwrap_or_else(|| "main".to_string());
249 let api = api.repo(Repo::with_revision(
250 model_id.to_string(),
251 RepoType::Model,
252 revision,
253 ));
254
255 let config_path = api_get_file(&api, model_id, "config.json")?;
256 let config = std::fs::read_to_string(&config_path).with_context(|| {
257 format!(
258 "Failed to read config.json from cache at {}",
259 config_path.display()
260 )
261 })?;
262
263 let sentence_transformers =
264 api_get_file(&api, model_id, "config_sentence_transformers.json").is_ok();
265
266 Ok((config, sentence_transformers))
267}
268
269fn api_get_file(api: &ApiRepo, model_id: &str, file: &str) -> Result<PathBuf> {
270 let model_id = Path::new(model_id);
271 if model_id.exists() {
272 let path = model_id.join(file);
273 if path.exists() {
274 Ok(path)
275 } else {
276 anyhow::bail!("File {file} not found at {}", model_id.display())
277 }
278 } else {
279 Ok(api.get(file)?)
280 }
281}
282
283fn infer_kind(config: &str, sentence_transformers: bool) -> Result<TuneKind> {
284 if sentence_transformers {
285 return Ok(TuneKind::Embedding);
286 }
287 #[derive(Deserialize)]
288 struct AutoConfig {
289 architectures: Vec<String>,
290 }
291 let cfg: AutoConfig = serde_json::from_str(config)?;
292 if cfg.architectures.len() != 1 {
293 anyhow::bail!("Expected exactly one architecture in config");
294 }
295 let name = &cfg.architectures[0];
296 if MultimodalLoaderType::from_causal_lm_name(name).is_ok() {
297 return Ok(TuneKind::Multimodal);
298 }
299 if EmbeddingLoaderType::from_causal_lm_name(name).is_ok() {
300 return Ok(TuneKind::Embedding);
301 }
302 let _ = NormalLoaderType::from_causal_lm_name(name)?;
303 Ok(TuneKind::Normal)
304}
305
306fn all_candidates(backend: TuneBackend) -> Vec<Option<IsqType>> {
308 match backend {
309 TuneBackend::Metal => vec![
310 None, Some(IsqType::AFQ8),
312 Some(IsqType::AFQ6),
313 Some(IsqType::AFQ4),
314 Some(IsqType::AFQ3),
315 Some(IsqType::AFQ2),
316 ],
317 _ => vec![
318 None, Some(IsqType::Q8_0),
320 Some(IsqType::Q6K),
321 Some(IsqType::Q5K),
322 Some(IsqType::Q4K),
323 Some(IsqType::Q3K),
324 Some(IsqType::Q2K),
325 ],
326 }
327}
328
329fn default_candidates(profile: TuneProfile, backend: TuneBackend) -> Vec<IsqType> {
330 match backend {
331 TuneBackend::Metal => match profile {
332 TuneProfile::Quality => {
333 vec![IsqType::AFQ8, IsqType::AFQ6, IsqType::AFQ4, IsqType::AFQ3]
334 }
335 TuneProfile::Balanced => vec![IsqType::AFQ6, IsqType::AFQ4, IsqType::AFQ3],
336 TuneProfile::Fast => vec![IsqType::AFQ4, IsqType::AFQ3, IsqType::AFQ2],
337 },
338 _ => match profile {
339 TuneProfile::Quality => vec![
340 IsqType::Q8_0,
341 IsqType::Q6K,
342 IsqType::Q5K,
343 IsqType::Q4K,
344 IsqType::Q3K,
345 IsqType::Q2K,
346 ],
347 TuneProfile::Balanced => vec![IsqType::Q6K, IsqType::Q5K, IsqType::Q4K, IsqType::Q3K],
348 TuneProfile::Fast => vec![IsqType::Q4K, IsqType::Q3K, IsqType::Q2K],
349 },
350 }
351}
352
353fn quality_tier(isq: Option<IsqType>) -> QualityTier {
355 match isq {
356 None => QualityTier::Baseline,
357 Some(t) => match t {
358 IsqType::Q8_0 | IsqType::Q8_1 | IsqType::Q8K | IsqType::AFQ8 | IsqType::HQQ8 => {
359 QualityTier::NearLossless
360 }
361 IsqType::Q6K | IsqType::AFQ6 => QualityTier::Good,
362 IsqType::Q5_0 | IsqType::Q5_1 | IsqType::Q5K => QualityTier::Good,
363 IsqType::Q4_0 | IsqType::Q4_1 | IsqType::Q4K | IsqType::AFQ4 | IsqType::HQQ4 => {
364 QualityTier::Acceptable
365 }
366 IsqType::Q3K | IsqType::AFQ3 => QualityTier::Degraded,
367 IsqType::Q2K | IsqType::AFQ2 => QualityTier::Degraded,
368 _ => QualityTier::Acceptable,
369 },
370 }
371}
372
373fn isq_display_name(isq: Option<IsqType>) -> String {
375 match isq {
376 None => "None (FP16)".to_string(),
377 Some(t) => format!("{t:?}"),
378 }
379}
380
381#[allow(clippy::cast_possible_truncation)]
383fn total_vram(devices: &[Device]) -> u64 {
384 use crate::MemoryUsage;
385 devices
386 .iter()
387 .filter(|d| !matches!(d, Device::Cpu))
388 .filter_map(|d| MemoryUsage.query(d).ok().map(|m| m.total()))
389 .sum::<usize>() as u64
390}
391
392#[allow(clippy::cast_possible_truncation)]
394fn available_vram(devices: &[Device]) -> u64 {
395 use crate::MemoryUsage;
396 devices
397 .iter()
398 .filter(|d| !matches!(d, Device::Cpu))
399 .filter_map(|d| MemoryUsage.query(d).ok().map(|m| m.available()))
400 .sum::<usize>() as u64
401}
402
403#[allow(clippy::cast_possible_truncation)]
408fn calculate_max_context(
409 loader: &dyn DeviceMappedModelLoader,
410 config: &str,
411 model_size_bytes: u64,
412 available_vram_bytes: u64,
413 dtype: DType,
414) -> Result<(usize, bool)> {
415 let model_cfg = loader.model_config(config)?;
416 let native_max_seq_len = model_cfg.max_seq_len();
417
418 if model_size_bytes >= available_vram_bytes {
419 return Ok((0, false));
420 }
421
422 let remaining_bytes = available_vram_bytes - model_size_bytes;
423
424 let kv_elems_per_token = model_cfg.kv_cache_elements_per_token();
427 let num_layers = model_cfg.num_layers();
428
429 let dtype_size = dtype.size_in_bytes();
431 let kv_bytes_per_token = kv_elems_per_token * dtype_size * num_layers;
432
433 if kv_bytes_per_token == 0 {
434 return Ok((native_max_seq_len, true));
435 }
436
437 let calculated_max = remaining_bytes as usize / kv_bytes_per_token;
438
439 let is_at_model_max = calculated_max >= native_max_seq_len;
441 Ok((calculated_max.min(native_max_seq_len), is_at_model_max))
442}
443
444fn map_for_candidate(
445 loader: &dyn DeviceMappedModelLoader,
446 config: &str,
447 dtype: DType,
448 params: &AutoDeviceMapParams,
449 devices: &[Device],
450 isq: Option<IsqType>,
451) -> Result<(DeviceMapMetadata, usize)> {
452 let pack_factor = isq.map(|i| i.pack_factor(dtype)).unwrap_or(1);
453 let layer_sizes = loader.layer_sizes_in_bytes(config, dtype, pack_factor, None)?;
454 let non_mapped = loader.non_mapped_size_in_bytes(config, dtype, pack_factor, None)?;
455 let total = layer_sizes.iter().sum::<usize>() + non_mapped;
456 let map = crate::pipeline::get_device_layers_for_loader(
457 loader,
458 config,
459 loader.num_layers(config)?,
460 layer_sizes,
461 non_mapped,
462 total,
463 devices,
464 dtype,
465 params,
466 None,
467 )?;
468 Ok((map, total))
469}
470
471#[allow(
472 clippy::cast_precision_loss,
473 clippy::cast_possible_truncation,
474 clippy::cast_sign_loss
475)]
476pub fn auto_tune(req: AutoTuneRequest) -> Result<AutoTuneResult> {
477 let model_id = model_id_from_selected(&req.model);
478 match &req.model {
479 ModelSelected::GGUF { .. }
480 | ModelSelected::GGML { .. }
481 | ModelSelected::LoraGGUF { .. }
482 | ModelSelected::XLoraGGUF { .. }
483 | ModelSelected::LoraGGML { .. }
484 | ModelSelected::XLoraGGML { .. } => {
485 anyhow::bail!("Auto-tuning is not supported for pre-quantized GGUF/GGML models.");
486 }
487 ModelSelected::DiffusionPlain { .. } | ModelSelected::Speech { .. } => {
488 anyhow::bail!("Auto-tuning is not supported for diffusion or speech models.");
489 }
490 _ => {}
491 }
492
493 let hf_cache_path = hf_cache_path_from_model(&req.model);
494 let (config, sentence_transformers) = load_config_artifacts(
495 &model_id,
496 &req.token_source,
497 req.hf_revision.clone(),
498 hf_cache_path,
499 )?;
500
501 let kind = match &req.model {
502 ModelSelected::MultimodalPlain { .. } => TuneKind::Multimodal,
503 ModelSelected::Embedding { .. } => TuneKind::Embedding,
504 _ => infer_kind(&config, sentence_transformers)?,
505 };
506
507 let mut params = get_auto_device_map_params(&req.model)?;
508 if matches!(kind, TuneKind::Multimodal) {
509 params = params.maybe_promote_to_multimodal();
510 }
511
512 let devices = select_devices(req.force_cpu)?;
513 let backend = backend_from_devices(&devices);
514
515 let dtype = {
516 let model_dtype = get_model_dtype(&req.model)?;
517 let refs = devices.iter().collect::<Vec<_>>();
518 model_dtype.try_into_dtype(&refs)?
519 };
520
521 let loader_normal = AutoNormalLoader;
522 let loader_multimodal = AutoMultimodalLoader;
523 let loader_embedding = AutoEmbeddingLoader;
524 let loader: &dyn DeviceMappedModelLoader = match kind {
525 TuneKind::Normal => &loader_normal,
526 TuneKind::Multimodal => &loader_multimodal,
527 TuneKind::Embedding => &loader_embedding,
528 };
529
530 let preferred_candidates: Vec<Option<IsqType>> =
532 req.requested_isq.map(|t| vec![Some(t)]).unwrap_or_else(|| {
533 default_candidates(req.profile, backend)
534 .into_iter()
535 .map(Some)
536 .collect()
537 });
538
539 let mut warnings = Vec::new();
540 let mut notes = Vec::new();
541
542 if matches!(kind, TuneKind::Embedding) {
543 notes.push("Detected embedding model configuration.".to_string());
544 }
545 if matches!(kind, TuneKind::Multimodal) {
546 notes.push("Detected multimodal model configuration.".to_string());
547 }
548
549 let total_vram_bytes = total_vram(&devices);
551 let avail_vram_bytes = available_vram(&devices);
552
553 let all_isq_candidates = all_candidates(backend);
555 let mut tune_candidates: Vec<TuneCandidate> = Vec::new();
556 let mut recommended_idx: Option<usize> = None;
557
558 for isq in all_isq_candidates {
559 let result = map_for_candidate(loader, &config, dtype, ¶ms, &devices, isq);
560
561 let (fit_status, estimated_size, device_layers_cli) = match &result {
562 Ok((map, total_size)) => {
563 let layers = map.device_layers();
564 let is_hybrid = layers
565 .map(|l| l.iter().any(|d| d.ordinal == usize::MAX))
566 .unwrap_or(false);
567 let status = if is_hybrid {
568 FitStatus::Hybrid
569 } else {
570 FitStatus::Fits
571 };
572 (status, *total_size as u64, map.to_cli_spec())
573 }
574 Err(_) => {
575 let pack_factor = isq.map(|i| i.pack_factor(dtype)).unwrap_or(1);
577 let layer_sizes = loader
578 .layer_sizes_in_bytes(&config, dtype, pack_factor, None)
579 .unwrap_or_default();
580 let non_mapped = loader
581 .non_mapped_size_in_bytes(&config, dtype, pack_factor, None)
582 .unwrap_or(0);
583 let est_size = (layer_sizes.iter().sum::<usize>() + non_mapped) as u64;
584 (FitStatus::TooLarge, est_size, None)
585 }
586 };
587
588 let vram_usage = if total_vram_bytes > 0 {
589 (estimated_size as f32) / (total_vram_bytes as f32)
590 } else {
591 1.0
592 };
593
594 let (context_room, context_is_model_max) =
595 calculate_max_context(loader, &config, estimated_size, avail_vram_bytes, dtype)
596 .unwrap_or((0, false));
597
598 let candidate = TuneCandidate {
599 isq,
600 isq_name: isq_display_name(isq),
601 estimated_size_bytes: estimated_size,
602 vram_usage_percent: vram_usage,
603 max_context_tokens: context_room,
604 context_is_model_max,
605 quality: quality_tier(isq),
606 fit_status,
607 device_layers_cli,
608 recommended: false, };
610
611 tune_candidates.push(candidate);
612 }
613
614 for pref in &preferred_candidates {
616 if let Some(idx) = tune_candidates.iter().position(|c| {
617 c.isq == *pref && matches!(c.fit_status, FitStatus::Fits | FitStatus::Hybrid)
618 }) {
619 tune_candidates[idx].recommended = true;
620 recommended_idx = Some(idx);
621 break;
622 }
623 }
624
625 if recommended_idx.is_none() {
627 if let Some(idx) = tune_candidates
628 .iter()
629 .position(|c| matches!(c.fit_status, FitStatus::Fits | FitStatus::Hybrid))
630 {
631 tune_candidates[idx].recommended = true;
632 recommended_idx = Some(idx);
633 }
634 }
635
636 let (recommended_isq, device_layers, device_layers_cli, recommended_command) =
638 if let Some(idx) = recommended_idx {
639 let rec = &tune_candidates[idx];
640 let isq_flag = rec
641 .isq
642 .map(|i| format!(" --isq {:?}", i).to_lowercase())
643 .unwrap_or_default();
644 let cmd = format!("hanzo serve -m {model_id}{isq_flag}");
645 (rec.isq, None, rec.device_layers_cli.clone(), cmd)
646 } else {
647 (None, None, None, format!("hanzo serve -m {model_id}"))
648 };
649
650 let paged_attn_mode = if backend != TuneBackend::Cpu && paged_attn_supported() {
651 Some("auto".to_string())
652 } else {
653 Some("off".to_string())
654 };
655
656 for c in &tune_candidates {
658 if matches!(c.fit_status, FitStatus::TooLarge) && c.isq.is_some() {
659 warnings.push(format!(
660 "{} ({:.1} GB) exceeds available VRAM",
661 c.isq_name,
662 c.estimated_size_bytes as f64 / 1e9
663 ));
664 }
665 }
666
667 if recommended_idx.is_none() {
668 anyhow::bail!(
669 "No suitable quantization level fits on the available devices. Try a smaller model or enable CPU offload."
670 );
671 }
672
673 Ok(AutoTuneResult {
674 model_id,
675 profile: req.profile,
676 backend: backend_name(backend),
677 candidates: tune_candidates,
678 recommended_isq,
679 device_layers,
680 device_layers_cli,
681 paged_attn_mode,
682 recommended_command,
683 total_vram_bytes,
684 warnings,
685 notes,
686 })
687}