Skip to main content

llm_manager/
models.rs

1use std::path::PathBuf;
2use std::time::Duration;
3
4use serde::{Deserialize, Serialize};
5
/// Lifecycle state the manager tracks for a single model.
#[derive(Clone, Debug, PartialEq)]
pub enum ModelState {
    /// Known to the manager but not loaded.
    Available,
    /// A load is in progress.
    Loading,
    /// A benchmark is in progress.
    Benchmarking,
    /// Loaded; carries the port and process id recorded for it.
    Loaded { port: u16, pid: u32 },
    /// The last operation failed; `error` holds the message.
    Failed { error: String },
}
15
/// Ordering applied to search results.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SearchSort {
    Relevance,
    Downloads,
    Likes,
    Trending,
    CreatedAt,
}

impl SearchSort {
    /// Returns the sort order that follows `self`, wrapping from the last
    /// variant back to the first.
    pub fn next(self) -> Self {
        match self {
            Self::Relevance => Self::Downloads,
            Self::Downloads => Self::Likes,
            Self::Likes => Self::Trending,
            Self::Trending => Self::CreatedAt,
            Self::CreatedAt => Self::Relevance,
        }
    }

    /// Returns the human-readable label for this sort order.
    pub fn label(self) -> &'static str {
        match self {
            Self::Relevance => "Relevance",
            Self::Downloads => "Downloads",
            Self::Likes => "Likes",
            Self::Trending => "Trending",
            Self::CreatedAt => "Created",
        }
    }
}
47
48/// A model found via HuggingFace search.
49#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
50pub struct SearchResult {
51    pub model_id: String,
52    pub model_name: String,
53    pub tags: Vec<String>,
54    pub downloads: u64,
55    pub likes: u64,
56    pub pipeline_tag: Option<String>,
57    pub size: Option<u64>,
58    pub parameters: Option<String>,
59    pub capabilities: Vec<String>,
60    pub context_length: Option<u32>,
61    pub readme: Option<String>,
62    /// Quantization type extracted from GGUF metadata (e.g. "Q4_K_M", "Q8_0").
63    pub quantization: Option<String>,
64    /// License extracted from tags (e.g. "apache-2.0", "llama3.1").
65    pub license: Option<String>,
66    /// HuggingFace trending score.
67    pub trending_score: i64,
68    /// Creation timestamp string.
69    pub created_at: Option<String>,
70    /// Whether a matching GGUF file is already downloaded locally.
71    #[serde(default)]
72    pub downloaded: bool,
73}
74
75/// Download progress information.
76#[derive(Debug, Clone)]
77pub struct DownloadState {
78    pub model_id: String,
79    pub filename: String,
80    pub total_bytes: u64,
81    pub downloaded_bytes: u64,
82    pub status: DownloadStatus,
83    pub cancelled: bool,
84    pub cancel_token: Option<std::sync::Arc<std::sync::atomic::AtomicBool>>,
85    /// Download control: 1=downloading, 2=paused, 3=cancelled
86    pub download_state: u8,
87    /// Shared atomic state for pausing/resuming the download loop
88    pub download_state_arc: Option<std::sync::Arc<std::sync::atomic::AtomicU8>>,
89    pub start_time: std::time::Instant,
90    pub bytes_per_second: f64,
91    /// Filesystem path where the download is being saved.
92    pub dest: Option<std::path::PathBuf>,
93}
94
95impl DownloadState {
96    pub fn new(model_id: String, filename: String, total_bytes: u64) -> Self {
97        Self {
98            model_id,
99            filename,
100            total_bytes,
101            downloaded_bytes: 0,
102            status: DownloadStatus::Downloading,
103            cancelled: false,
104            cancel_token: None,
105            download_state: 1,
106            download_state_arc: None,
107            start_time: std::time::Instant::now(),
108            bytes_per_second: 0.0,
109            dest: None,
110        }
111    }
112}
113
114impl ModelSettings {
115    /// Get the version string for the currently active backend.
116    pub fn get_active_backend_version(&self) -> Option<&String> {
117        match self.backend {
118            Backend::Cpu => self.llama_cpp_version_cpu.as_ref(),
119            Backend::Vulkan => self.llama_cpp_version_vulkan.as_ref(),
120            Backend::Rocm => self.llama_cpp_version_rocm.as_ref(),
121            Backend::RocmLemonade => self.llama_cpp_version_rocm_lemonade.as_ref(),
122            Backend::Cuda => self.llama_cpp_version_cuda.as_ref(),
123            _ => None,
124        }
125    }
126
127    /// Get the display version string for the currently active backend (defaults to "latest").
128    pub fn get_active_backend_version_display(&self) -> &str {
129        self.get_active_backend_version()
130            .map(|s| s.as_str())
131            .unwrap_or("latest")
132    }
133
134    /// Set the version string for the currently active backend.
135    pub fn set_active_backend_version(&mut self, tag: Option<String>) {
136        match self.backend {
137            Backend::Cpu => self.llama_cpp_version_cpu = tag,
138            Backend::Vulkan => self.llama_cpp_version_vulkan = tag,
139            Backend::Rocm => self.llama_cpp_version_rocm = tag,
140            Backend::RocmLemonade => self.llama_cpp_version_rocm_lemonade = tag,
141            Backend::Cuda => self.llama_cpp_version_cuda = tag,
142            _ => {}
143        }
144    }
145}
146
/// Strip a trailing `.gguf` extension from a model name.
///
/// The extension is matched ASCII-case-insensitively, so `.gguf`, `.GGUF`
/// and mixed-case spellings such as `.Gguf` are all removed. Names without
/// the extension (including names shorter than the extension) are returned
/// unchanged.
pub fn strip_gguf(name: &str) -> &str {
    const EXT: &str = ".gguf";
    name.len()
        .checked_sub(EXT.len())
        // `get` yields `None` if the cut would split a multi-byte character,
        // in which case the tail cannot be the (pure ASCII) extension.
        .filter(|&cut| {
            name.get(cut..)
                .map_or(false, |tail| tail.eq_ignore_ascii_case(EXT))
        })
        .map_or(name, |cut| &name[..cut])
}
153
/// Normalize a host string for use in URLs and CLI arguments.
///
/// - Blank input (empty or whitespace only) becomes `127.0.0.1`.
/// - Only the first whitespace-delimited token is kept, which drops display
///   suffixes such as the `(127.0.0.1)` in `localhost (127.0.0.1)`.
/// - A token containing `:` that is not already bracketed is wrapped in
///   `[...]` (IPv6 literal form).
pub fn clean_host(host: &str) -> String {
    match host.split_whitespace().next() {
        None => "127.0.0.1".to_string(),
        Some(token) if token.contains(':') && !token.starts_with('[') => {
            format!("[{}]", token)
        }
        Some(token) => token.to_string(),
    }
}
170
/// Produce the display form of a host string: the empty string and
/// `127.0.0.1` are shown as `localhost (127.0.0.1)`; any other value is
/// returned unchanged.
pub fn format_host(host: &str) -> &str {
    if host.is_empty() || host == "127.0.0.1" {
        "localhost (127.0.0.1)"
    } else {
        host
    }
}
178
179impl From<crate::config::DefaultParams> for ModelSettings {
180    fn from(dp: crate::config::DefaultParams) -> Self {
181        Self {
182            context_length: dp.context_length,
183            threads: dp.threads,
184            threads_batch: dp.threads_batch,
185            batch_size: dp.batch_size,
186            ubatch_size: dp.ubatch_size,
187            parallel: dp.parallel,
188            max_concurrent_predictions: dp.max_concurrent_predictions,
189            uniform_cache: dp.uniform_cache,
190            kv_cache_offload: dp.kv_cache_offload,
191            cache_type_k: dp.cache_type_k,
192            cache_type_v: dp.cache_type_v,
193            keep: dp.keep,
194            swa_full: dp.swa_full,
195            mlock: dp.mlock,
196            mmap: dp.mmap,
197            numa: dp.numa,
198            system_prompt: dp.system_prompt,
199            system_prompt_preset_name: dp.system_prompt_preset_name,
200
201            gpu_layers_mode: match dp.gpu_layers {
202                n if n < 0 => GpuLayersMode::All,
203                _ => dp.gpu_layers_mode,
204            },
205            split_mode: dp.split_mode,
206            tensor_split: dp.tensor_split,
207            main_gpu: dp.main_gpu,
208            fit: dp.fit,
209            lora: dp.lora,
210            lora_scaled: dp.lora_scaled,
211            rpc: dp.rpc,
212            embedding: dp.embedding,
213            flash_attn: dp.flash_attn,
214            expert_count: dp.expert_count,
215            jinja: dp.jinja,
216            chat_template: dp.chat_template,
217            chat_template_kwargs: dp.chat_template_kwargs,
218            seed: dp.seed,
219            temperature: dp.temperature,
220            top_k: dp.top_k,
221            top_p: dp.top_p,
222            min_p: dp.min_p,
223            typical_p: dp.typical_p,
224            mirostat: dp.mirostat,
225            mirostat_lr: dp.mirostat_lr,
226            mirostat_ent: dp.mirostat_ent,
227            ignore_eos: dp.ignore_eos,
228            samplers: dp.samplers,
229            repeat_penalty: dp.repeat_penalty,
230            repeat_last_n: dp.repeat_last_n,
231            presence_penalty: dp.presence_penalty,
232            frequency_penalty: dp.frequency_penalty,
233            dry_multiplier: dp.dry_multiplier,
234            dry_base: dp.dry_base,
235            dry_allowed_length: dp.dry_allowed_length,
236            dry_penalty_last_n: dp.dry_penalty_last_n,
237            rope_scaling: dp.rope_scaling,
238            rope_scale: dp.rope_scale,
239            rope_freq_base: dp.rope_freq_base,
240            rope_freq_scale: dp.rope_freq_scale,
241            rope_yarn_enabled: dp.rope_yarn_enabled,
242            host: dp.host,
243            port: dp.port,
244            timeout: dp.timeout,
245            cache_prompt: dp.cache_prompt,
246            cache_reuse: dp.cache_reuse,
247            webui: dp.webui,
248            max_tokens: dp.max_tokens,
249            cache_type: dp.cache_type,
250            backend: dp.backend,
251            llama_cpp_version_cpu: dp.llama_cpp_version_cpu,
252            llama_cpp_version_vulkan: dp.llama_cpp_version_vulkan,
253            llama_cpp_version_rocm: dp.llama_cpp_version_rocm,
254            llama_cpp_version_rocm_lemonade: dp.llama_cpp_version_rocm_lemonade,
255            llama_cpp_version_cuda: dp.llama_cpp_version_cuda,
256            api_endpoint_enabled: dp.api_endpoint_enabled,
257            api_endpoint_port: dp.api_endpoint_port,
258            ws_server_enabled: dp.ws_server_enabled,
259            ws_server_port: dp.ws_server_port,
260            ws_server_auth_key: dp.ws_server_auth_key,
261            ws_server_tls_enabled: dp.ws_server_tls_enabled,
262            ws_server_tls_cert: dp.ws_server_tls_cert,
263            ws_server_tls_key: dp.ws_server_tls_key,
264            spec_type: dp.spec_type,
265            draft_tokens: dp.draft_tokens,
266            tags: dp.tags,
267        }
268    }
269}
270
271/// How to handle GPU layer offloading.
272#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize, Hash)]
273#[derive(Default)]
274pub enum GpuLayersMode {
275    #[default]
276    Auto,
277    Specific(u32),
278    All,
279}
280
281
/// Status of a download as reported in `DownloadState::status`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum DownloadStatus {
    /// Transfer is in progress.
    Downloading,
    /// Transfer is paused.
    Paused,
    /// Transfer finished.
    Complete,
    /// Transfer failed; carries the error message.
    Error(String),
    /// Transfer was cancelled.
    Cancelled,
}
290
291// ── Cache type enums ──────────────────────────────────────────
292
293/// Main KV cache data type.
294#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Hash, Default)]
295pub enum CacheType {
296    #[serde(rename = "f16")]
297    #[default]
298    F16,
299    #[serde(rename = "bf16")]
300    BF16,
301    #[serde(rename = "fq8_0")]
302    Fq8_0,
303    #[serde(rename = "fq4_1")]
304    Fq4_1,
305}
306
307impl std::fmt::Display for CacheType {
308    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
309        match self {
310            CacheType::F16 => write!(f, "f16"),
311            CacheType::BF16 => write!(f, "bf16"),
312            CacheType::Fq8_0 => write!(f, "fq8_0"),
313            CacheType::Fq4_1 => write!(f, "fq4_1"),
314        }
315    }
316}
317
318/// KV cache quantization type.
319#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default, Hash)]
320pub enum CacheQuantType {
321    #[serde(rename = "f32")]
322    F32,
323    #[serde(rename = "f16")]
324    #[default]
325    F16,
326    #[serde(rename = "bf16")]
327    BF16,
328    #[serde(rename = "q8_0")]
329    Q8_0,
330    #[serde(rename = "q4_0")]
331    Q4_0,
332    #[serde(rename = "q4_1")]
333    Q4_1,
334    #[serde(rename = "iq4_nl")]
335    Iq4Nl,
336    #[serde(rename = "q5_0")]
337    Q5_0,
338    #[serde(rename = "q5_1")]
339    Q5_1,
340}
341
342pub type CacheTypeK = CacheQuantType;
343pub type CacheTypeV = CacheQuantType;
344
345impl CacheQuantType {
346    pub fn from_u8(n: u8) -> Self {
347        match n {
348            0 => Self::F32,
349            1 => Self::F16,
350            2 => Self::BF16,
351            3 => Self::Q8_0,
352            4 => Self::Q5_1,
353            5 => Self::Q5_0,
354            6 => Self::Q4_1,
355            7 => Self::Q4_0,
356            8 => Self::Iq4Nl,
357            _ => Self::F16,
358        }
359    }
360    pub fn next(&self) -> Self {
361        match self {
362            Self::F32 => Self::F16,
363            Self::F16 => Self::BF16,
364            Self::BF16 => Self::Q8_0,
365            Self::Q8_0 => Self::Q5_1,
366            Self::Q5_1 => Self::Q5_0,
367            Self::Q5_0 => Self::Q4_1,
368            Self::Q4_1 => Self::Q4_0,
369            Self::Q4_0 => Self::Iq4Nl,
370            Self::Iq4Nl => Self::F32,
371        }
372    }
373    pub fn prev(&self) -> Self {
374        match self {
375            Self::F32 => Self::Iq4Nl,
376            Self::F16 => Self::F32,
377            Self::BF16 => Self::F16,
378            Self::Q8_0 => Self::BF16,
379            Self::Q5_1 => Self::Q8_0,
380            Self::Q5_0 => Self::Q5_1,
381            Self::Q4_1 => Self::Q5_0,
382            Self::Q4_0 => Self::Q4_1,
383            Self::Iq4Nl => Self::Q4_0,
384        }
385    }
386}
387
388impl std::fmt::Display for CacheQuantType {
389    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
390        match self {
391            Self::F32 => write!(f, "f32"),
392            Self::F16 => write!(f, "f16"),
393            Self::BF16 => write!(f, "bf16"),
394            Self::Q8_0 => write!(f, "q8_0"),
395            Self::Q4_0 => write!(f, "q4_0"),
396            Self::Q4_1 => write!(f, "q4_1"),
397            Self::Iq4Nl => write!(f, "iq4_nl"),
398            Self::Q5_0 => write!(f, "q5_0"),
399            Self::Q5_1 => write!(f, "q5_1"),
400        }
401    }
402}
403
404impl From<&str> for CacheQuantType {
405    fn from(s: &str) -> Self {
406        match s {
407            "F32" => Self::F32,
408            "F16" => Self::F16,
409            "BF16" => Self::BF16,
410            "Q8_0" => Self::Q8_0,
411            "Q4_0" => Self::Q4_0,
412            "Q4_1" => Self::Q4_1,
413            "Iq4Nl" => Self::Iq4Nl,
414            "Q5_0" => Self::Q5_0,
415            "Q5_1" => Self::Q5_1,
416            _ => Self::F16, // Default or error handling
417        }
418    }
419}
420
421/// Split mode for multi-GPU.
422#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Hash, Default)]
423pub enum SplitMode {
424    #[serde(rename = "none")]
425    None,
426    #[serde(rename = "layer")]
427    #[default]
428    Layer,
429    #[serde(rename = "row")]
430    Row,
431    #[serde(rename = "tensor")]
432    Tensor,
433}
434
435impl std::fmt::Display for SplitMode {
436    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
437        match self {
438            SplitMode::None => write!(f, "none"),
439            SplitMode::Layer => write!(f, "layer"),
440            SplitMode::Row => write!(f, "row"),
441            SplitMode::Tensor => write!(f, "tensor"),
442        }
443    }
444}
445
446/// NUMA optimization mode.
447#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Hash, Default)]
448pub enum NumMode {
449    #[serde(rename = "none")]
450    #[default]
451    None,
452    #[serde(rename = "distribute")]
453    Distribute,
454    #[serde(rename = "isolate")]
455    Isolate,
456    #[serde(rename = "numactl")]
457    Numactl,
458}
459
460impl std::fmt::Display for NumMode {
461    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
462        match self {
463            NumMode::None => write!(f, "none"),
464            NumMode::Distribute => write!(f, "distribute"),
465            NumMode::Isolate => write!(f, "isolate"),
466            NumMode::Numactl => write!(f, "numactl"),
467        }
468    }
469}
470
471/// RoPE frequency scaling method.
472#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Hash, Default)]
473pub enum RopeScaling {
474    #[serde(rename = "none")]
475    #[default]
476    None,
477    #[serde(rename = "linear")]
478    Linear,
479    #[serde(rename = "yarn")]
480    Yarn,
481}
482
483impl std::fmt::Display for RopeScaling {
484    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
485        match self {
486            RopeScaling::None => write!(f, "none"),
487            RopeScaling::Linear => write!(f, "linear"),
488            RopeScaling::Yarn => write!(f, "yarn"),
489        }
490    }
491}
492
493/// Mirostat version.
494#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Hash, Default)]
495pub enum Mirostat {
496    #[serde(rename = "0")]
497    #[default]
498    Off,
499    #[serde(rename = "1")]
500    V1,
501    #[serde(rename = "2")]
502    Mirostat2,
503}
504
505impl std::fmt::Display for Mirostat {
506    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
507        match self {
508            Mirostat::Off => write!(f, "off"),
509            Mirostat::V1 => write!(f, "1"),
510            Mirostat::Mirostat2 => write!(f, "2"),
511        }
512    }
513}
514
515/// Sampler order string (semicolon-separated).
516/// Common types: penalties, dry, top_n_sigma, top_k, typ_p, top_p, min_p, xtc, temperature
517#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
518pub struct Samplers(pub String);
519
520impl Default for Samplers {
521    fn default() -> Self {
522        Self("penalties;dry;top_n_sigma;top_k;typ_p;top_p;min_p;xtc;temperature".to_string())
523    }
524}
525
526impl std::fmt::Display for Samplers {
527    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
528        write!(f, "{}", self.0)
529    }
530}
531
532/// Backend used to run the llama.cpp server.
533#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
534pub enum Backend {
535    #[serde(rename = "cpu")]
536    #[default]
537    Cpu,
538    #[serde(rename = "vulkan")]
539    Vulkan,
540    #[serde(rename = "rocm")]
541    Rocm,
542    #[serde(rename = "rocm_lemonade")]
543    RocmLemonade,
544    #[serde(rename = "cuda")]
545    Cuda,
546    #[serde(rename = "cpu_arm64")]
547    CpuArm64,
548    #[serde(rename = "win_cpu")]
549    CpuWindows,
550    #[serde(rename = "win_vulkan")]
551    VulkanWindows,
552    #[serde(rename = "win_cuda_12_4")]
553    CudaWindows12_4,
554    #[serde(rename = "win_cuda_13_1")]
555    CudaWindows13_1,
556    #[serde(rename = "win_hip")]
557    HipWindows,
558    #[serde(rename = "macos_arm64")]
559    CpuMacosArm64,
560    #[serde(rename = "macos_x64")]
561    CpuMacosX64,
562}
563
564impl Backend {
565    /// Get the identifier used for directory names and asset prefixes.
566    pub fn slug(&self) -> &'static str {
567        match self {
568            Backend::Cpu => "cpu",
569            Backend::Vulkan => "vulkan",
570            Backend::Rocm => "rocm",
571            Backend::RocmLemonade => "rocm-lemonade",
572            Backend::Cuda => "cuda",
573            Backend::CpuArm64 => "cpu-arm64",
574            Backend::CpuWindows => "win-cpu",
575            Backend::VulkanWindows => "win-vulkan",
576            Backend::CudaWindows12_4 => "win-cuda-12.4",
577            Backend::CudaWindows13_1 => "win-cuda-13.1",
578            Backend::HipWindows => "win-hip",
579            Backend::CpuMacosArm64 => "macos-arm64",
580            Backend::CpuMacosX64 => "macos-x64",
581        }
582    }
583
584    /// Returns true if this backend is for Linux.
585    pub fn is_linux(self) -> bool {
586        matches!(
587            self,
588            Backend::Cpu
589                | Backend::Vulkan
590                | Backend::Rocm
591                | Backend::RocmLemonade
592                | Backend::Cuda
593                | Backend::CpuArm64
594        )
595    }
596
597    /// Returns true if this backend is for Windows.
598    pub fn is_windows(self) -> bool {
599        matches!(
600            self,
601            Backend::CpuWindows
602                | Backend::VulkanWindows
603                | Backend::CudaWindows12_4
604                | Backend::CudaWindows13_1
605                | Backend::HipWindows
606        )
607    }
608
609    /// Returns true if this backend is for macOS.
610    pub fn is_macos(self) -> bool {
611        matches!(self, Backend::CpuMacosArm64 | Backend::CpuMacosX64)
612    }
613
614    /// Parse backend from string representation.
615    pub fn from_str(s: &str) -> Self {
616        let s = s.to_lowercase();
617        if s.starts_with("vulkan") || s.starts_with("vk") {
618            Backend::Vulkan
619        } else if s.starts_with("rocm") || s.starts_with("ro") {
620            if s.contains("lemonade") {
621                Backend::RocmLemonade
622            } else {
623                Backend::Rocm
624            }
625        } else if s.starts_with("cuda") || s.starts_with("cu") {
626            Backend::Cuda
627        } else {
628            Backend::Cpu // Default
629        }
630    }
631}
632
633impl std::fmt::Display for Backend {
634    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
635        write!(f, "{}", self.slug())
636    }
637}
638
639/// Server mode: normal (single model) or router (multiple models).
640#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
641pub enum ServerMode {
642    #[serde(rename = "normal")]
643    #[default]
644    Normal,
645    #[serde(rename = "router")]
646    Router,
647    #[serde(rename = "bench_gpu", alias = "bench")]
648    Bench,
649    #[serde(rename = "bench_tune")]
650    BenchTune,
651}
652
653impl std::fmt::Display for ServerMode {
654    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
655        match self {
656            ServerMode::Normal => write!(f, "Normal"),
657            ServerMode::Router => write!(f, "Router (XP!)"),
658            ServerMode::Bench => write!(f, "Bench GPU"),
659            ServerMode::BenchTune => write!(f, "BenchTune"),
660        }
661    }
662}
663
664// ── ModelSettings ─────────────────────────────────────────────
665//
666// WHEN ADDING A NEW PARAMETER to ModelSettings, update ALL of these locations:
667//   1.  src/models.rs:668       — ModelSettings struct field + doc comment
668//   2.  src/models.rs:179       — From<DefaultParams> for ModelSettings (field mapping)
669//   3.  src/config.rs:564       — DefaultParams struct field + serde attribute
670//   4.  src/config.rs:785       — DefaultParams Default impl (default value)
671//   5.  src/config.rs:175       — ModelOverride struct field (Option<T>)
672//   6.  src/config.rs:299       — ModelOverride::from_settings() (Some(s.field))
673//   7.  src/config.rs:385       — ModelOverride::apply() (one of the 3 macro calls)
674//   8.  src/tui/settings.rs:312 — all_fields() SettingField entry (id, name, section, etc.)
//   9.  src/tui/settings.rs:1149 — profile_settings_parts() diff macro call
676//  10.  src/tui/app/profiles.rs:80 — settings_fingerprint() hash call
677//  11.  src/tui/event/helpers.rs:125 — sync_global_settings() (if global-scoped)
678//  12.  src/tui/event/panel/settings.rs — key handlers for numeric/toggle edit
679//
680// The derived PartialEq on ModelSettings and DefaultParams guarantees
681// is_dirty() compares ALL fields — no field can be missed.
682// The build.rs script checks field counts at compile time.
683
684/// Settings for loading a model via llama.cpp server.
685#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
686pub struct ModelSettings {
687    // ── Loading ──────────────────────────────────────────────
688    /// Size of the prompt context.
689    pub context_length: u32,
690    /// Number of CPU threads for generation.
691    pub threads: u32,
692    /// Number of CPU threads for batch processing.
693    pub threads_batch: u32,
694    /// Logical maximum batch size.
695    pub batch_size: u32,
696    /// Physical maximum batch size (micro-batch).
697    pub ubatch_size: u32,
698    /// Max concurrent predictions (sequences).
699    pub parallel: u32,
700    /// Max concurrent predictions (requests in flight). None means no --parallel argument.
701    pub max_concurrent_predictions: Option<u32>,
702    /// Use uniform (unified) KV cache across all sequences.
703    pub uniform_cache: bool,
704    /// Offload KV cache to system RAM.
705    pub kv_cache_offload: bool,
706    /// KV cache data type for K.
707    pub cache_type_k: Option<CacheTypeK>,
708    /// KV cache data type for V.
709    pub cache_type_v: Option<CacheTypeV>,
710    /// Keep N tokens from the initial prompt.
711    pub keep: i32,
712    /// Use full-size SWA cache.
713    pub swa_full: bool,
714    /// Force system to keep model in RAM.
715    pub mlock: bool,
716    /// Memory-map the model.
717    pub mmap: bool,
718    /// NUMA optimization.
719    pub numa: NumMode,
720    /// System prompt.
721    pub system_prompt: String,
722    /// Name of the system prompt preset currently selected.
723    pub system_prompt_preset_name: String,
724
725    // ── GPU ──────────────────────────────────────────────────
726    /// GPU layer offloading mode.
727    pub gpu_layers_mode: GpuLayersMode,
728    /// Split mode across multiple GPUs.
729    pub split_mode: SplitMode,
730    /// Fraction of model offloaded to each GPU (comma-separated).
731    pub tensor_split: String,
732    /// Main GPU index.
733    pub main_gpu: i32,
734    /// Whether to adjust arguments to fit device memory.
735    pub fit: bool,
736    /// Path to LoRA adapter.
737    pub lora: Option<PathBuf>,
738    /// Path to LoRA adapter with scale.
739    pub lora_scaled: Option<(PathBuf, f32)>,
740    /// RPC servers.
741    pub rpc: String,
742    /// Restrict to embedding use case.
743    pub embedding: bool,
744    /// Enable Flash Attention.
745    pub flash_attn: bool,
746    /// Active experts per token (MoE models, -1 = model default).
747    pub expert_count: i32,
748    /// Use Jinja template engine for chat.
749    pub jinja: bool,
750    /// Custom chat template string.
751    pub chat_template: Option<String>,
752    /// JSON string for --chat-template-kwargs (e.g. {"enable_thinking": false}).
753    pub chat_template_kwargs: Option<String>,
754
755    // ── Sampling ─────────────────────────────────────────────
756    /// RNG seed (-1 = random).
757    pub seed: i32,
758    /// Temperature.
759    pub temperature: f32,
760    /// Top-k sampling (0 = disabled).
761    pub top_k: i32,
762    /// Top-p sampling (1.0 = disabled).
763    pub top_p: f32,
764    /// Minimum probability for a token.
765    pub min_p: f32,
766    /// Locally typical sampling parameter p.
767    pub typical_p: f32,
768    /// Mirostat version (0=off, 1=Mirostat, 2=Mirostat2).
769    pub mirostat: Mirostat,
770    /// Mirostat learning rate (eta).
771    pub mirostat_lr: f32,
772    /// Mirostat target entropy (tau).
773    pub mirostat_ent: f32,
774    /// Ignore end-of-stream token.
775    pub ignore_eos: bool,
776    /// Sampler order string.
777    pub samplers: Samplers,
778
779    // ── Repetition Control ───────────────────────────────────
780    /// Penalize repeat sequence of tokens.
781    pub repeat_penalty: f32,
782    /// Last N tokens to consider for repeat penalty.
783    pub repeat_last_n: i32,
784    /// Repeat alpha presence penalty.
785    pub presence_penalty: Option<f32>,
786    /// Repeat alpha frequency penalty.
787    pub frequency_penalty: Option<f32>,
788    /// DRY sampling multiplier.
789    pub dry_multiplier: f32,
790    /// DRY sampling base value.
791    pub dry_base: f32,
792    /// DRY allowed length.
793    pub dry_allowed_length: i32,
794    /// DRY penalty last N.
795    pub dry_penalty_last_n: i32,
796
797    // ── RoPE ─────────────────────────────────────────────────
798    /// RoPE frequency scaling method.
799    pub rope_scaling: RopeScaling,
800    /// RoPE context scaling factor.
801    pub rope_scale: f32,
802    /// RoPE base frequency.
803    pub rope_freq_base: f32,
804    /// RoPE frequency scaling factor.
805    pub rope_freq_scale: f32,
806    /// Enable Yarn RoPE scaling mode.
807    pub rope_yarn_enabled: bool,
808
809    // ── Server ───────────────────────────────────────────────
810    /// Host address.
811    pub host: String,
812    /// Port.
813    pub port: u16,
814    /// Server timeout in seconds.
815    pub timeout: u32,
816    /// Whether to enable prompt caching.
817    pub cache_prompt: bool,
818    /// Min chunk size for cache reuse.
819    pub cache_reuse: u32,
820    /// Whether to enable WebUI.
821    pub webui: bool,
822
823    // ── Other ────────────────────────────────────────────────
824    /// Max tokens to predict.
825    pub max_tokens: Option<u32>,
826    /// Cache type (legacy, kept for compatibility).
827    pub cache_type: CacheType,
828    /// Backend (cpu/vulkan).
829    pub backend: Backend,
830    /// llama.cpp release tag for CPU backend (e.g. "b1234" or None for latest).
831    pub llama_cpp_version_cpu: Option<String>,
832    /// llama.cpp release tag for Vulkan backend (e.g. "b1234" or None for latest).
833    pub llama_cpp_version_vulkan: Option<String>,
834    /// llama.cpp release tag for ROCm backend (e.g. "b1234" or None for latest).
835    pub llama_cpp_version_rocm: Option<String>,
836    /// Lemonade llama.cpp release tag for ROCm backend.
837    pub llama_cpp_version_rocm_lemonade: Option<String>,
838    /// llama.cpp release tag for CUDA backend.
839    pub llama_cpp_version_cuda: Option<String>,
840    /// Whether to enable the API proxy server.
841    pub api_endpoint_enabled: bool,
842    /// Port for the API proxy server.
843    pub api_endpoint_port: u16,
844    /// Speculative decoding type (e.g., "draft-mtp", "ngram-simple", "" for off).
845    pub spec_type: String,
846    /// Number of draft tokens for MTP.
847    pub draft_tokens: u32,
848    /// Tags for the model.
849    pub tags: Vec<String>,
850    /// Whether to enable the WebSocket dashboard server.
851    pub ws_server_enabled: bool,
852    pub ws_server_port: u16,
853    pub ws_server_auth_key: Option<String>,
854    pub ws_server_tls_enabled: bool,
855    pub ws_server_tls_cert: Option<String>,
856    pub ws_server_tls_key: Option<String>,
857}
858
859impl Default for ModelSettings {
860    fn default() -> Self {
861        let mut s: Self = crate::config::DefaultParams::default().into();
862        // Override fields that differ from DefaultParams defaults
863        s.uniform_cache = false;
864        s.cache_type_k = Some(CacheTypeK::F16);
865        s.cache_type_v = Some(CacheTypeV::F16);
866        s.cache_type = CacheType::default();
867        s.backend = Backend::Cpu;
868        s.presence_penalty = Some(0.0);
869        s.frequency_penalty = Some(0.0);
870        s
871    }
872}
873
874impl ModelSettings {
875    /// Create ModelSettings from config defaults, applying model-specific overrides.
876    pub fn from_config(config: &crate::config::Config) -> Self {
877        let mut settings: ModelSettings = config.default.clone().into();
878        settings.ws_server_enabled = config.ws_server.enabled;
879        settings.ws_server_port = config.ws_server.port;
880        settings.ws_server_auth_key = config.ws_server.auth_key.clone();
881        settings.ws_server_tls_enabled = config.ws_server.tls_enabled;
882        settings.ws_server_tls_cert = config.ws_server.tls_cert.clone();
883        settings.ws_server_tls_key = config.ws_server.tls_key.clone();
884        settings
885    }
886}
887
/// Prompt sent by default when a tuning/benchmark session is started.
pub const BENCHMARK_PROMPT: &str = "Create Mona Lisa image in ascii art using text, number, symbol, everything possible. this should be the perfect painting.";
890
/// A discovered model file.
///
/// Plain data record; nothing in this type checks that `path` still exists.
#[derive(Debug, Clone)]
pub struct DiscoveredModel {
    /// Location of the model file.
    pub path: PathBuf,
    /// Model name (NOTE(review): presumably derived from the file name — confirm
    /// at the discovery site).
    pub name: String,
    /// Size of the file (NOTE(review): presumably bytes — confirm).
    pub file_size: u64,
    /// Path relative to `model_dir`, used for display.
    pub display_name: String,
}
899
/// Parsed GGUF metadata for a model, cached to avoid re-parsing the file.
///
/// `Default` yields zeros, empty strings and an empty capability list; the
/// provenance notes below describe what `GgufMetadata::from_path` fills in.
#[derive(Debug, Clone, Default)]
pub struct GgufMetadata {
    /// `<arch>.block_count` (falling back to `llama.block_count`); 0 when absent.
    pub layers: u32,
    /// `<arch>.embedding_length` (with the `llama.` fallback); 0 when absent.
    pub hidden_size: u32,
    /// `<arch>.context_length` (with the `llama.` fallback); 0 when absent.
    pub n_ctx_train: u32,
    /// `<arch>.attention.head_count` (with the `llama.` fallback); 0 when absent.
    pub n_head: u32,
    /// `<arch>.attention.head_count_kv` (with the `llama.` fallback); 0 when absent.
    pub n_kv_head: u32,
    /// `general.architecture`; empty when the key is missing.
    pub arch: String,
    /// Human-readable label derived from the numeric `general.file_type` code.
    pub file_type: String,
    /// Not populated by `from_path`; stays empty unless a caller sets it.
    pub quantization: String,
    /// Value returned by the GGUF decoder's `model_parameters()`.
    pub model_parameters: String,
    /// `general.domain`; empty when the key is missing.
    pub domain: String,
    /// String entries of `general.capabilities`, plus `"chat"` when the file
    /// carries a `tokenizer.chat_template` key.
    pub capabilities: Vec<String>,
    /// `tokenizer.ggml.model`; empty when the key is missing.
    pub tokenizer: String,
    /// Number of entries in `tokenizer.ggml.tokens`; 0 when absent.
    pub vocab_size: u32,
    /// `mtp.draft_tokens`, read only when `arch` is `"mtp"`; 0 otherwise.
    pub draft_tokens: u32,
}
918
919impl GgufMetadata {
920    pub fn from_path(path: &std::path::Path) -> anyhow::Result<Self> {
921        let path_str = path.to_string_lossy();
922        let mut container = gguf_rs::get_gguf_container(&path_str)
923            .map_err(|e| anyhow::anyhow!("Failed to get GGUF container: {}", e))?;
924        let model_data = container
925            .decode()
926            .map_err(|e| anyhow::anyhow!("Failed to decode GGUF: {}", e))?;
927
928        let mut meta = Self::default();
929
930        let extract_str = |key: &str| -> String {
931            model_data
932                .metadata()
933                .get(key)
934                .and_then(|v| v.as_str().map(|s| s.to_string()))
935                .unwrap_or_default()
936        };
937
938        let extract_num = |key: &str| -> Option<u64> {
939            model_data.metadata().get(key).and_then(|v| {
940                v.as_u64()
941                    .or_else(|| v.as_i64().map(|x| x as u64))
942                    .or_else(|| v.as_f64().map(|x| x as u64))
943            })
944        };
945
946        meta.arch = extract_str("general.architecture");
947        let prefix = if meta.arch.is_empty() {
948            "llama"
949        } else {
950            &meta.arch
951        };
952
953        let get_num_with_fallback = |suffix: &str| -> u32 {
954            extract_num(&format!("{}.{}", prefix, suffix))
955                .or_else(|| {
956                    if prefix != "llama" {
957                        extract_num(&format!("llama.{}", suffix))
958                    } else {
959                        None
960                    }
961                })
962                .unwrap_or(0) as u32
963        };
964
965        meta.layers = get_num_with_fallback("block_count");
966        meta.hidden_size = get_num_with_fallback("embedding_length");
967        meta.n_ctx_train = get_num_with_fallback("context_length");
968        meta.n_head = get_num_with_fallback("attention.head_count");
969        meta.n_kv_head = get_num_with_fallback("attention.head_count_kv");
970
971        if let Some(value) = model_data.metadata().get("tokenizer.ggml.tokens")
972            && let Some(arr) = value.as_array()
973        {
974            meta.vocab_size = arr.len() as u32;
975        }
976
977        if meta.arch == "mtp" {
978            meta.draft_tokens = extract_num("mtp.draft_tokens").unwrap_or(0) as u32;
979        }
980
981        if let Some(v) = extract_num("general.file_type") {
982            meta.file_type = match v {
983                0 => "F32".to_string(),
984                1 => "F16".to_string(),
985                2 => "Q4_0".to_string(),
986                3 => "Q4_1".to_string(),
987                7 => "Q8_0".to_string(),
988                8 => "Q5_0".to_string(),
989                9 => "Q5_1".to_string(),
990                10 => "Q2_K".to_string(),
991                11 => "Q3_K_S".to_string(),
992                12 => "Q3_K_M".to_string(),
993                13 => "Q3_K_L".to_string(),
994                14 => "Q4_K_S".to_string(),
995                15 => "Q4_K_M".to_string(),
996                16 => "Q5_K_S".to_string(),
997                17 => "Q5_K_M".to_string(),
998                18 => "Q6_K".to_string(),
999                19 => "IQ2_XXS".to_string(),
1000                20 => "IQ2_XS".to_string(),
1001                21 => "IQ3_XXS".to_string(),
1002                22 => "IQ1_S".to_string(),
1003                23 => "IQ4_NL".to_string(),
1004                24 => "IQ3_S".to_string(),
1005                25 => "IQ2_S".to_string(),
1006                26 => "IQ4_XS".to_string(),
1007                _ => format!("Unknown ({})", v),
1008            };
1009        }
1010
1011        if let Some(value) = model_data.metadata().get("general.capabilities")
1012            && let Some(arr) = value.as_array()
1013        {
1014            for v in arr {
1015                if let Some(s) = v.as_str() {
1016                    meta.capabilities.push(s.to_string());
1017                }
1018            }
1019        }
1020
1021        if model_data
1022            .metadata()
1023            .contains_key("tokenizer.chat_template")
1024        {
1025            meta.capabilities.push("chat".to_string());
1026        }
1027
1028        meta.tokenizer = extract_str("tokenizer.ggml.model");
1029        meta.domain = extract_str("general.domain");
1030        meta.model_parameters = model_data.model_parameters();
1031
1032        Ok(meta)
1033    }
1034}
1035
/// Metrics reported by the llama.cpp server.
///
/// The hand-written `Default` impl in this file produces an all-zero,
/// not-loaded snapshot. Units of the memory fields are not visible from this
/// type (NOTE(review): confirm bytes vs MiB at the producer).
#[derive(Debug, Clone)]
pub struct ServerMetrics {
    /// Whether the model is currently loaded.
    pub loaded: bool,
    /// Tokens per second (NOTE(review): relation to `gen_tps` is not visible here — confirm).
    pub tps: f64,
    /// Prompt-processing tokens per second.
    pub prompt_tps: f64,
    /// CPU usage figure (NOTE(review): presumably a percentage — confirm).
    pub cpu_usage: f64,
    /// GPU memory in use for this model.
    pub gpu_mem_used: u64,
    /// Total GPU memory available.
    pub gpu_mem_total: u64,
    /// System RAM in use.
    pub ram_used: u64,
    /// Context tokens currently in use.
    pub ctx_used: u32,
    /// Maximum context size.
    pub ctx_max: u32,
    /// Sum of gpu_mem_used across all loaded models (for Total VRAM display).
    pub total_vram_used: u64,
    /// Number of decoded tokens from print_timing logs.
    pub decoded_tokens: u64,
    /// Generation tokens per second parsed from llama.cpp log output (e.g., "tg = 64.45 t/s").
    pub gen_tps: f64,
    /// Estimated latency per generated token in milliseconds.
    pub latency_per_token_ms: f64,
    /// Estimated prompt processing latency in milliseconds (1000 / prompt_tps).
    pub prompt_latency_ms: f64,
}
1059
/// GPU device buffer reported by llama-server during model loading.
#[derive(Debug, Clone)]
pub struct GPUBuffer {
    /// Device label as it appears in the server output.
    pub device: String,
    /// Buffer size in MiB.
    pub buffer_size_mib: f64,
}
1066
/// Progress information during model loading, parsed from llama-server log output.
///
/// The derived `Default` starts with every `Option` at `None`, a zero tensor
/// count and no buffers.
#[derive(Debug, Clone, Default)]
pub struct LoadProgress {
    /// Total number of layers in the model.
    pub layers_total: Option<u32>,
    /// Number of layers already offloaded to GPU.
    pub layers_loaded: Option<u32>,
    /// Total number of tensors in the model (from "Loading tensor X of Y" log).
    pub tensors_total: Option<u32>,
    /// Number of tensors loaded (counted from dot-lines in log).
    pub tensors_loaded: u32,
    /// GPU device buffers with their sizes.
    pub buffers: Vec<GPUBuffer>,
}
1081
1082impl Default for ServerMetrics {
1083    fn default() -> Self {
1084        Self {
1085            loaded: false,
1086            tps: 0.0,
1087            prompt_tps: 0.0,
1088            cpu_usage: 0.0,
1089            gpu_mem_used: 0,
1090            gpu_mem_total: 0,
1091            ram_used: 0,
1092            ctx_used: 0,
1093            ctx_max: 0,
1094            total_vram_used: 0,
1095            decoded_tokens: 0,
1096            gen_tps: 0.0,
1097            latency_per_token_ms: 0.0,
1098            prompt_latency_ms: 0.0,
1099        }
1100    }
1101}
1102
/// WebSocket-friendly metrics snapshot (serializable, no internal state).
///
/// Built by `WsMetrics::from_metrics`, which copies the live figures from
/// `ServerMetrics` and the launch/sampling parameters from `ModelSettings`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct WsMetrics {
    /// Name of the model the snapshot describes.
    pub model_name: String,
    /// Whether the model is loaded (copied from `ServerMetrics::loaded`).
    pub loaded: bool,
    /// Textual model state supplied by the caller.
    pub state: String,
    pub tps: f64,
    pub prompt_tps: f64,
    pub ctx_used: u32,
    pub ctx_max: u32,
    pub cpu_usage: f64,
    pub gpu_mem_used: u64,
    pub gpu_mem_total: u64,
    pub ram_used: u64,
    pub latency_per_token_ms: f64,
    pub decoded_tokens: u64,
    pub gen_tps: f64,
    /// Seconds since the Unix epoch when the snapshot was created
    /// (0 if the system clock is before the epoch).
    pub timestamp: u64,
    // Server command
    /// Display form of the server command line, when the caller provides one.
    pub cmd_display: Option<String>,
    // LLM settings
    pub threads: u32,
    pub threads_batch: u32,
    pub context_length: u32,
    pub ubatch_size: u32,
    pub batch_size: u32,
    pub temperature: f32,
    pub top_k: u32,
    pub top_p: f32,
    pub min_p: f32,
    pub typical_p: f32,
    pub seed: i32,
    pub repeat_penalty: f32,
    pub repeat_last_n: i32,
    pub presence_penalty: Option<f32>,
    pub frequency_penalty: Option<f32>,
    /// Mirostat mode as a number: 0 = off, 1 = v1, 2 = v2.
    /// `from_metrics` always fills this with `Some`.
    pub mirostat: Option<u32>,
    /// `from_metrics` always fills this with `Some`.
    pub mirostat_lr: Option<f32>,
    /// `from_metrics` always fills this with `Some`.
    pub mirostat_ent: Option<f32>,
    pub max_tokens: Option<u32>,
    pub flash_attn: bool,
    pub kv_cache_offload: bool,
    /// String form of the K cache type, when one is set.
    pub cache_type_k: Option<String>,
    /// String form of the V cache type, when one is set.
    pub cache_type_v: Option<String>,
    pub uniform_cache: bool,
    pub mlock: bool,
    pub mmap: bool,
    pub embedding: bool,
    pub jinja: bool,
    pub ignore_eos: bool,
    /// String form of the sampler configuration.
    pub samplers: String,
    pub expert_count: u32,
    /// `"Auto"`, `"All"`, or the specific layer count rendered as a number.
    pub gpu_layers: String,
    /// String form of the selected backend.
    pub backend: String,
    /// Display string for the active backend's llama.cpp version.
    pub llama_cpp_version: String,
    pub spec_type: String,
    pub draft_tokens: u32,
}
1161
1162impl WsMetrics {
1163    pub fn from_metrics(
1164        metrics: &ServerMetrics,
1165        model_name: &str,
1166        state: &str,
1167        settings: &crate::models::ModelSettings,
1168        cmd_display: Option<&str>,
1169    ) -> Self {
1170        use std::time::{SystemTime, UNIX_EPOCH};
1171        let timestamp = SystemTime::now()
1172            .duration_since(UNIX_EPOCH)
1173            .map(|d| d.as_secs())
1174            .unwrap_or(0);
1175        let gpu_layers = match settings.gpu_layers_mode {
1176            crate::models::GpuLayersMode::Auto => "Auto".to_string(),
1177            crate::models::GpuLayersMode::Specific(n) => n.to_string(),
1178            crate::models::GpuLayersMode::All => "All".to_string(),
1179        };
1180        Self {
1181            model_name: model_name.to_string(),
1182            loaded: metrics.loaded,
1183            state: state.to_string(),
1184            tps: metrics.tps,
1185            prompt_tps: metrics.prompt_tps,
1186            ctx_used: metrics.ctx_used,
1187            ctx_max: metrics.ctx_max,
1188            cpu_usage: metrics.cpu_usage,
1189            gpu_mem_used: metrics.gpu_mem_used,
1190            gpu_mem_total: metrics.gpu_mem_total,
1191            ram_used: metrics.ram_used,
1192            latency_per_token_ms: metrics.latency_per_token_ms,
1193            decoded_tokens: metrics.decoded_tokens,
1194            gen_tps: metrics.gen_tps,
1195            timestamp,
1196            cmd_display: cmd_display.map(String::from),
1197            threads: settings.threads,
1198            threads_batch: settings.threads_batch,
1199            context_length: settings.context_length,
1200            ubatch_size: settings.ubatch_size,
1201            batch_size: settings.batch_size,
1202            temperature: settings.temperature,
1203            top_k: settings.top_k as u32,
1204            top_p: settings.top_p,
1205            min_p: settings.min_p,
1206            typical_p: settings.typical_p,
1207            seed: settings.seed,
1208            repeat_penalty: settings.repeat_penalty,
1209            repeat_last_n: settings.repeat_last_n,
1210            presence_penalty: settings.presence_penalty,
1211            frequency_penalty: settings.frequency_penalty,
1212            mirostat: Some(match settings.mirostat {
1213                crate::models::Mirostat::Off => 0,
1214                crate::models::Mirostat::V1 => 1,
1215                crate::models::Mirostat::Mirostat2 => 2,
1216            }),
1217            mirostat_lr: Some(settings.mirostat_lr),
1218            mirostat_ent: Some(settings.mirostat_ent),
1219            max_tokens: settings.max_tokens,
1220            flash_attn: settings.flash_attn,
1221            kv_cache_offload: settings.kv_cache_offload,
1222            cache_type_k: settings.cache_type_k.map(|k| k.to_string()),
1223            cache_type_v: settings.cache_type_v.map(|k| k.to_string()),
1224            uniform_cache: settings.uniform_cache,
1225            mlock: settings.mlock,
1226            mmap: settings.mmap,
1227            embedding: settings.embedding,
1228            jinja: settings.jinja,
1229            ignore_eos: settings.ignore_eos,
1230            samplers: settings.samplers.to_string(),
1231            expert_count: settings.expert_count as u32,
1232            gpu_layers,
1233            backend: settings.backend.to_string(),
1234            llama_cpp_version: settings.get_active_backend_version_display().to_string(),
1235            spec_type: settings.spec_type.clone(),
1236            draft_tokens: settings.draft_tokens,
1237        }
1238    }
1239}
1240
1241/// Estimate VRAM usage (in MiB) for a model with the given settings.
1242///
1243/// Model file size is the size of the GGUF file in MiB. The model itself
1244/// takes 1x its size in VRAM (loaded as-is). KV cache is the dominant
1245/// variable cost — it scales with context_length, batch_size, and layers.
1246///
1247/// The KV cache formula accounts for:
1248/// - Actual GQA ratio from model metadata (n_kv_head / n_head)
1249/// - FlashAttention: reduces KV cache storage by ~2x
1250/// - Unified KV cache: shares KV across sequences, dividing by parallel count
1251/// - KV cache quantization (q4_0, q5_0, q8_0, etc.)
1252pub fn estimate_vram_mib(
1253    model_mib: u64,
1254    settings: &ModelSettings,
1255    total_layers: u32,
1256    hidden_size_opt: Option<u32>,
1257    n_head_opt: Option<u32>,
1258    n_kv_head_opt: Option<u32>,
1259    gpu_mem_total_mib: u64,
1260) -> u64 {
1261    let model_mib_f = model_mib as f64;
1262
1263    // Compute how much of the model is loaded into VRAM based on GPU layers.
1264    let gpu_layers = match settings.gpu_layers_mode {
1265        GpuLayersMode::Auto => {
1266            // Heuristic: ~60% of layers when Auto (llama.cpp will decide at runtime)
1267            if total_layers > 0 {
1268                (total_layers as f64 * 0.6) as u32
1269            } else {
1270                20
1271            }
1272        }
1273        GpuLayersMode::Specific(n) => {
1274            if total_layers > 0 {
1275                n.min(total_layers)
1276            } else {
1277                n
1278            }
1279        }
1280        GpuLayersMode::All => {
1281            if total_layers > 0 {
1282                total_layers
1283            } else {
1284                32
1285            }
1286        }
1287    };
1288
1289    // Model weights loaded into VRAM: proportional to GPU layers.
1290    let model_vram = if total_layers > 0 && gpu_layers > 0 {
1291        model_mib_f * (gpu_layers as f64 / total_layers as f64).min(1.0)
1292    } else if gpu_layers > 0 {
1293        model_mib_f
1294    } else {
1295        0.0
1296    };
1297
1298    if matches!(settings.gpu_layers_mode, GpuLayersMode::Specific(0)) {
1299        return 0; // CPU only
1300    }
1301
1302    // Heuristic for hidden_size if not provided:
1303    // A 7B model (4-bit) is ~4000 MiB and has hidden=4096.
1304    let hidden_size = match hidden_size_opt {
1305        Some(h) => h as f64,
1306        None => {
1307            let params_est = model_mib_f / 550.0;
1308            (1024.0 * params_est.sqrt().max(1.0) * 1.5).max(512.0)
1309        }
1310    };
1311
1312    // ── KV cache estimation ─────────────────────────────────────
1313
1314    // GQA ratio: real KV heads vs query heads.
1315    // If n_kv_head == n_head, ratio = 1.0 (no reduction).
1316    // If n_kv_head < n_head, ratio < 1.0 (KV cache is smaller).
1317    let gqa_ratio = match (n_head_opt, n_kv_head_opt) {
1318        (Some(n_head), Some(n_kv_head)) if n_head > 0 => n_kv_head as f64 / n_head as f64,
1319        _ => 1.0, // fallback: assume no GQA
1320    };
1321
1322    // FlashAttention reduces KV cache storage by ~2x because it doesn't
1323    // need to keep the full attention matrix in memory.
1324    let flash_attn_factor = if settings.flash_attn { 0.5 } else { 1.0 };
1325
1326    // Unified KV cache shares a single KV buffer across all sequences.
1327    let uniform_cache_factor = if settings.uniform_cache {
1328        1.0 / settings.parallel as f64
1329    } else {
1330        1.0
1331    };
1332
1333    // KV cache in MiB:
1334    // Formula: 2 * n_layer * n_ctx * n_embd_kv * sizeof(type)
1335    // n_embd_kv = hidden_size * gqa_ratio
1336    // The KV cache is allocated for the total number of model layers,
1337    // not just the number layers loaded into the GPU (gpu_layers).
1338    // However only gpu_layers * sizeof(type) contributes to the VRAM cost.
1339    let kv_mib = (2.0
1340        * hidden_size
1341        * settings.context_length as f64
1342        * total_layers as f64
1343        * gqa_ratio
1344        * gpu_layers as f64
1345        / total_layers as f64  // VRAM cost: only GPU-loaded portion of KV cache
1346        * flash_attn_factor
1347        * uniform_cache_factor
1348        * kv_quant_bytes(
1349            settings.cache_type_k.unwrap_or(CacheTypeK::F16),
1350            settings.cache_type_v.unwrap_or(CacheTypeV::F16)
1351        ))
1352        / (1024.0 * 1024.0);
1353
1354    // Activation overhead during inference (proportional to batch * hidden).
1355    // Increased multiplier to 8.0 (from 2.0) to be more pessimistic about scratch buffers.
1356    let activation_mib = (settings.batch_size as f64 * hidden_size * 8.0) / (1024.0 * 1024.0);
1357
1358    // Fixed overhead for driver, fragmentation, and small meta buffers.
1359    // Use 3.8% of max VRAM, falling back to 500MiB if unknown.
1360    let fixed_overhead = if gpu_mem_total_mib > 0 {
1361        gpu_mem_total_mib as f64 * 0.038
1362    } else {
1363        500.0
1364    };
1365
1366    let total_mib = model_vram + kv_mib + activation_mib + fixed_overhead + 550.0;
1367
1368    total_mib.ceil() as u64
1369}
1370
1371/// Return the average KV cache element size in bytes for the given K/V types.
1372///
1373/// KV cache stores K and V separately, potentially at different precisions.
1374/// We average the two to get a single per-element size.
1375fn kv_quant_bytes(k_type: CacheQuantType, v_type: CacheQuantType) -> f64 {
1376    let get_bytes = |t: CacheQuantType| match t {
1377        CacheQuantType::F32 => 4.0,
1378        CacheQuantType::F16 | CacheQuantType::BF16 => 2.0,
1379        CacheQuantType::Q8_0 => 1.0,
1380        CacheQuantType::Q5_0 | CacheQuantType::Q5_1 => 0.625, // 5 bits
1381        CacheQuantType::Q4_0 | CacheQuantType::Q4_1 | CacheQuantType::Iq4Nl => 0.5, // 4 bits
1382    };
1383    (get_bytes(k_type) + get_bytes(v_type)) / 2.0
1384}
1385
1386impl ModelSettings {
1387    /// Check if this settings differs from `other` in any field.
1388    /// Uses derived PartialEq which compares all fields — compiler-enforced.
1389    pub fn is_dirty(&self, other: &Self) -> bool {
1390        self != other
1391    }
1392}
1393
1394// Benchmark Tuning types
/// Configuration of one benchmark-tuning run.
///
/// `BenchTuneConfig::new` fills in the default sweep ranges (all disabled), a
/// 30-second test duration, `n_predict = 512` and the default `bench_mode`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct BenchTuneConfig {
    /// Model file to benchmark.
    pub model_path: PathBuf,
    /// Iteration count (NOTE(review): presumably repetitions per parameter
    /// combination — confirm at the runner).
    pub num_iterations: u32,
    /// Prompt text used for the benchmark.
    pub prompt: String,
    /// Sweep definitions; only entries with `enabled == true` are expanded by
    /// `generate_combinations`.
    pub params_to_test: Vec<BenchTuneParam>,
    /// Time budget (NOTE(review): presumably per test — confirm at the runner).
    pub test_duration: Duration,
    /// How parameters are applied (see `BenchTuneMode`).
    pub bench_mode: BenchTuneMode,
    /// Number of tokens to predict (NOTE(review): presumably per request — confirm).
    pub n_predict: u32,
    /// Optional JSON text with chat-template arguments; `new` sets it to
    /// `{"enable_thinking": false}`.
    pub chat_template_kwargs: Option<String>,
}
1406
/// One sweepable parameter: an inclusive numeric range walked in fixed steps.
///
/// `BenchTuneConfig::generate_combinations` recognises the names `temperature`,
/// `top_p`, `top_k`, `repeat_penalty`, `flash_attn`, `threads`, `batch_size`
/// and `expert_count`; other names are ignored. For `flash_attn` a value of at
/// least 0.5 means `true`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchTuneParam {
    /// Parameter name, matched against the list above.
    pub name: String,
    /// First value of the sweep.
    pub min: f64,
    /// Upper bound; generated values are clamped to it.
    pub max: f64,
    /// Increment between consecutive values.
    pub step: f64,
    /// Whether this parameter takes part in the sweep.
    pub enabled: bool,
}
1415
1416impl PartialEq for BenchTuneParam {
1417    fn eq(&self, other: &Self) -> bool {
1418        self.name == other.name
1419            && self.min.to_bits() == other.min.to_bits()
1420            && self.max.to_bits() == other.max.to_bits()
1421            && self.step.to_bits() == other.step.to_bits()
1422            && self.enabled == other.enabled
1423    }
1424}
1425impl Eq for BenchTuneParam {}
1426
/// One concrete parameter combination to benchmark.
///
/// `None` means the parameter is not part of the sweep.
/// `BenchTuneConfig::generate_combinations` never sets `context_length`,
/// `spec_type` or `draft_tokens`; they are always `None` in its output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchTuneParamValue {
    pub temperature: Option<f64>,
    pub top_p: Option<f64>,
    pub top_k: Option<i64>,
    pub repeat_penalty: Option<f64>,
    pub context_length: Option<u32>,
    pub batch_size: Option<u32>,
    pub flash_attn: Option<bool>,
    pub threads: Option<u32>,
    pub expert_count: Option<i32>,
    pub spec_type: Option<String>,
    pub draft_tokens: Option<u32>,
}
1441
1442impl PartialEq for BenchTuneParamValue {
1443    fn eq(&self, other: &Self) -> bool {
1444        self.temperature.map(|v| v.to_bits()) == other.temperature.map(|v| v.to_bits())
1445            && self.top_p.map(|v| v.to_bits()) == other.top_p.map(|v| v.to_bits())
1446            && self.top_k == other.top_k
1447            && self.repeat_penalty.map(|v| v.to_bits()) == other.repeat_penalty.map(|v| v.to_bits())
1448            && self.context_length == other.context_length
1449            && self.batch_size == other.batch_size
1450            && self.flash_attn == other.flash_attn
1451            && self.threads == other.threads
1452            && self.expert_count == other.expert_count
1453            && self.spec_type == other.spec_type
1454            && self.draft_tokens == other.draft_tokens
1455    }
1456}
1457impl Eq for BenchTuneParamValue {}
1458
/// Outcome of benchmarking one parameter combination.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchTuneResult {
    /// The parameter combination that was tested.
    pub params: BenchTuneParamValue,
    /// Metrics for the combination (NOTE(review): presumably aggregated over
    /// `per_iteration_metrics` — confirm at the runner).
    pub metrics: BenchTuneMetrics,
    /// Collected output texts (NOTE(review): presumably one per iteration — confirm).
    pub outputs: Vec<String>,
    /// Metrics of the individual iterations.
    pub per_iteration_metrics: Vec<BenchTuneMetrics>,
    /// Settings the combination was applied on top of, when recorded.
    pub base_settings: Option<ModelSettings>,
}
1467
/// Throughput and latency figures of a benchmark measurement.
///
/// NOTE(review): units are not visible from this type; the `*_tps` fields are
/// presumably tokens per second and the time fields seconds or milliseconds —
/// confirm at the producer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchTuneMetrics {
    /// Prompt-processing throughput.
    pub prompt_tps: f64,
    /// Generation throughput.
    pub generation_tps: f64,
    /// Combined prompt + generation throughput.
    pub combined_tps: f64,
    /// Latency per generated token.
    pub latency_per_token: f64,
    /// Time until the first token.
    pub first_token_time: f64,
}
1476
/// Serializable status of a benchmark-tuning run.
///
/// Has the same variants and fields as `BenchTuneProgress`;
/// `BenchTuneProgress::from_status` converts one into the other.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum BenchTuneStatus {
    /// Tuning is running.
    Running {
        /// Position of the current test.
        current: usize,
        /// Total number of tests.
        total: usize,
        /// Progress figure (NOTE(review): scale — 0..1 or percent — not visible here).
        progress: f32,
        /// Parameter combination currently being tested.
        current_params: BenchTuneParamValue,
    },
    /// Tuning is complete.
    Completed {
        total_tests: usize,
        successful_tests: usize,
        elapsed: Duration,
    },
    /// Tuning completed with some failures.
    PartiallyCompleted {
        total_tests: usize,
        successful_tests: usize,
        failed_tests: usize,
        elapsed: Duration,
    },
    /// Tuning was cancelled by the user.
    Cancelled {
        total_tests: usize,
        successful_tests: usize,
        failed_tests: usize,
        elapsed: Duration,
    },
    /// Tuning failed.
    Error {
        /// Description of the failure.
        error: String,
    },
}
1506
1507#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
1508#[derive(Default)]
1509pub enum BenchTuneMode {
1510    /// Runtime-only mode: sends all params in /completion request body, no server restarts
1511    RuntimeOnly,
1512    /// Full mode: spawns a new server for each parameter combination (tests server-level params)
1513    #[default]
1514    Full,
1515}
1516
1517
/// Progress status for benchmark tuning.
///
/// Non-serializable mirror of `BenchTuneStatus`; see
/// `BenchTuneProgress::from_status` for the conversion.
#[derive(Debug, Clone)]
pub enum BenchTuneProgress {
    /// Tuning is running.
    Running {
        /// Position of the current test.
        current: usize,
        /// Total number of tests.
        total: usize,
        /// Progress figure (NOTE(review): scale — 0..1 or percent — not visible here).
        progress: f32,
        /// Parameter combination currently being tested.
        current_params: BenchTuneParamValue,
    },
    /// Tuning is complete.
    Completed {
        total_tests: usize,
        successful_tests: usize,
        elapsed: Duration,
    },
    /// Tuning completed with some failures.
    PartiallyCompleted {
        total_tests: usize,
        successful_tests: usize,
        failed_tests: usize,
        elapsed: Duration,
    },
    /// Tuning was cancelled by the user.
    Cancelled {
        total_tests: usize,
        successful_tests: usize,
        failed_tests: usize,
        elapsed: Duration,
    },
    /// Tuning failed.
    Error { error: String },
}
1551
1552impl BenchTuneProgress {
1553    pub fn from_status(status: &BenchTuneStatus) -> Option<Self> {
1554        match status {
1555            BenchTuneStatus::Running {
1556                current,
1557                total,
1558                progress,
1559                current_params,
1560            } => Some(BenchTuneProgress::Running {
1561                current: *current,
1562                total: *total,
1563                progress: *progress,
1564                current_params: current_params.clone(),
1565            }),
1566            BenchTuneStatus::Completed {
1567                total_tests,
1568                successful_tests,
1569                elapsed,
1570            } => Some(BenchTuneProgress::Completed {
1571                total_tests: *total_tests,
1572                successful_tests: *successful_tests,
1573                elapsed: *elapsed,
1574            }),
1575            BenchTuneStatus::PartiallyCompleted {
1576                total_tests,
1577                successful_tests,
1578                failed_tests,
1579                elapsed,
1580            } => Some(BenchTuneProgress::PartiallyCompleted {
1581                total_tests: *total_tests,
1582                successful_tests: *successful_tests,
1583                failed_tests: *failed_tests,
1584                elapsed: *elapsed,
1585            }),
1586            BenchTuneStatus::Cancelled {
1587                total_tests,
1588                successful_tests,
1589                failed_tests,
1590                elapsed,
1591            } => Some(BenchTuneProgress::Cancelled {
1592                total_tests: *total_tests,
1593                successful_tests: *successful_tests,
1594                failed_tests: *failed_tests,
1595                elapsed: *elapsed,
1596            }),
1597            BenchTuneStatus::Error { error } => Some(BenchTuneProgress::Error {
1598                error: error.clone(),
1599            }),
1600        }
1601    }
1602}
1603
1604impl BenchTuneConfig {
1605    pub fn new(model_path: PathBuf, num_iterations: u32, prompt: String) -> Self {
1606        Self {
1607            model_path,
1608            num_iterations,
1609            prompt,
1610            params_to_test: vec![
1611                BenchTuneParam {
1612                    name: "temperature".to_string(),
1613                    min: 0.4,
1614                    max: 1.0,
1615                    step: 0.1,
1616                    enabled: false,
1617                },
1618                BenchTuneParam {
1619                    name: "top_p".to_string(),
1620                    min: 0.8,
1621                    max: 1.0,
1622                    step: 0.1,
1623                    enabled: false,
1624                },
1625                BenchTuneParam {
1626                    name: "top_k".to_string(),
1627                    min: 10.0,
1628                    max: 40.0,
1629                    step: 5.0,
1630                    enabled: false,
1631                },
1632                BenchTuneParam {
1633                    name: "repeat_penalty".to_string(),
1634                    min: 1.0,
1635                    max: 1.2,
1636                    step: 0.1,
1637                    enabled: false,
1638                },
1639                BenchTuneParam {
1640                    name: "flash_attn".to_string(),
1641                    min: 0.0,
1642                    max: 1.0,
1643                    step: 1.0,
1644                    enabled: false,
1645                },
1646                BenchTuneParam {
1647                    name: "threads".to_string(),
1648                    min: 4.0,
1649                    max: 16.0,
1650                    step: 4.0,
1651                    enabled: false,
1652                },
1653                BenchTuneParam {
1654                    name: "batch_size".to_string(),
1655                    min: 512.0,
1656                    max: 2048.0,
1657                    step: 512.0,
1658                    enabled: false,
1659                },
1660                BenchTuneParam {
1661                    name: "expert_count".to_string(),
1662                    min: 1.0,
1663                    max: 4.0,
1664                    step: 1.0,
1665                    enabled: false,
1666                },
1667            ],
1668            test_duration: Duration::from_secs(30),
1669            bench_mode: BenchTuneMode::default(),
1670            n_predict: 512,
1671            chat_template_kwargs: Some(r#"{"enable_thinking": false}"#.to_string()),
1672        }
1673    }
1674
1675    /// Generate all parameter combinations based on the config
1676    pub fn generate_combinations(&self) -> Vec<BenchTuneParamValue> {
1677        let mut temp_values = vec![None];
1678        let mut top_p_values = vec![None];
1679        let mut top_k_values = vec![None];
1680        let mut repeat_penalty_values = vec![None];
1681        let mut flash_attn_values = vec![None];
1682        let mut threads_values = vec![None];
1683        let mut batch_size_values = vec![None];
1684        let mut expert_count_values = vec![None];
1685
1686        for p in &self.params_to_test {
1687            if !p.enabled {
1688                continue;
1689            }
1690
1691            let vals: Vec<f64> = {
1692                let step_count = ((p.max - p.min) / p.step).ceil() as usize;
1693                (0..=step_count)
1694                    .map(|i| (p.min + (i as f64 * p.step)).min(p.max))
1695                    .collect()
1696            };
1697
1698            match p.name.as_str() {
1699                "temperature" => temp_values = vals.into_iter().map(Some).collect(),
1700                "top_p" => top_p_values = vals.into_iter().map(Some).collect(),
1701                "top_k" => top_k_values = vals.into_iter().map(|v| Some(v as i64)).collect(),
1702                "repeat_penalty" => repeat_penalty_values = vals.into_iter().map(Some).collect(),
1703                "flash_attn" => {
1704                    flash_attn_values = vals.into_iter().map(|v| Some(v >= 0.5)).collect()
1705                }
1706                "threads" => threads_values = vals.into_iter().map(|v| Some(v as u32)).collect(),
1707                "batch_size" => {
1708                    batch_size_values = vals.into_iter().map(|v| Some(v as u32)).collect()
1709                }
1710                "expert_count" => {
1711                    expert_count_values = vals.into_iter().map(|v| Some(v as i32)).collect()
1712                }
1713                _ => {}
1714            }
1715        }
1716
1717        let mut combinations = Vec::new();
1718        for &temp in &temp_values {
1719            for &top_p in &top_p_values {
1720                for &top_k in &top_k_values {
1721                    for &rp in &repeat_penalty_values {
1722                        for &fa in &flash_attn_values {
1723                            for &th in &threads_values {
1724                                for &bs in &batch_size_values {
1725                                    for &ec in &expert_count_values {
1726                                        combinations.push(BenchTuneParamValue {
1727                                            temperature: temp,
1728                                            top_p,
1729                                            top_k,
1730                                            repeat_penalty: rp,
1731                                            context_length: None,
1732                                            batch_size: bs,
1733                                            flash_attn: fa,
1734                                            threads: th,
1735                                            expert_count: ec,
1736                                            spec_type: None,
1737                                            draft_tokens: None,
1738                                        });
1739                                    }
1740                                }
1741                            }
1742                        }
1743                    }
1744                }
1745            }
1746        }
1747
1748        combinations
1749    }
1750
1751    /// Get total number of tests to run
1752    pub fn get_total_tests_count(&self) -> usize {
1753        self.generate_combinations().len()
1754    }
1755}
1756
1757// ── Parameter struct field count tests ──────────────────────────
1758
#[cfg(test)]
mod field_count_tests {
    use super::*;

    /// Verify ModelSettings has the expected number of fields.
    ///
    /// If this test fails — or `count_model_settings_fields` stops compiling —
    /// a field was added/removed: update the checklist in src/models.rs:665
    /// and all locations listed there.
    #[test]
    fn test_model_settings_field_count() {
        // The helper names every field in an exhaustive struct pattern and
        // derives the count from that same list, so the number asserted here
        // can only be reached when the list matches the struct definition.
        let s = ModelSettings::default();
        let field_count = count_model_settings_fields(&s);
        assert_eq!(
            field_count, 81,
            "ModelSettings has {} fields (expected 81). \
             Update the checklist at src/models.rs:665 and all locations listed there.",
            field_count
        );
    }

    /// Count the fields of `ModelSettings` by naming each one exactly once.
    ///
    /// The field list is expanded into an exhaustive struct pattern (no `..`),
    /// which gives a compile-time guarantee in both directions: a removed or
    /// renamed field is an unknown-field error, and a newly added field is a
    /// missing-field error. The returned count is the length of that same
    /// list, so it cannot drift from the fields actually named here.
    // The body is long only because it enumerates every field of the struct.
    #[allow(clippy::too_many_lines)]
    fn count_model_settings_fields(s: &ModelSettings) -> usize {
        // Expands one field list into (a) an exhaustive destructuring of the
        // value and (b) the number of names in the list.
        macro_rules! exhaustive_field_count {
            ($value:expr => $($field:ident),+ $(,)?) => {{
                // `_` sub-patterns bind nothing, so nothing is moved out of
                // the borrowed settings.
                let ModelSettings { $($field: _),+ } = $value;
                [$(stringify!($field)),+].len()
            }};
        }

        exhaustive_field_count!(s =>
            context_length,
            threads,
            threads_batch,
            batch_size,
            ubatch_size,
            parallel,
            max_concurrent_predictions,
            uniform_cache,
            kv_cache_offload,
            cache_type_k,
            cache_type_v,
            keep,
            swa_full,
            mlock,
            mmap,
            numa,
            system_prompt,
            system_prompt_preset_name,
            gpu_layers_mode,
            split_mode,
            tensor_split,
            main_gpu,
            fit,
            lora,
            lora_scaled,
            rpc,
            embedding,
            flash_attn,
            expert_count,
            jinja,
            chat_template,
            chat_template_kwargs,
            seed,
            temperature,
            top_k,
            top_p,
            min_p,
            typical_p,
            mirostat,
            mirostat_lr,
            mirostat_ent,
            ignore_eos,
            samplers,
            repeat_penalty,
            repeat_last_n,
            presence_penalty,
            frequency_penalty,
            dry_multiplier,
            dry_base,
            dry_allowed_length,
            dry_penalty_last_n,
            rope_scaling,
            rope_scale,
            rope_freq_base,
            rope_freq_scale,
            rope_yarn_enabled,
            host,
            port,
            timeout,
            cache_prompt,
            cache_reuse,
            webui,
            max_tokens,
            cache_type,
            backend,
            llama_cpp_version_cpu,
            llama_cpp_version_vulkan,
            llama_cpp_version_rocm,
            llama_cpp_version_rocm_lemonade,
            llama_cpp_version_cuda,
            api_endpoint_enabled,
            api_endpoint_port,
            spec_type,
            draft_tokens,
            tags,
            ws_server_enabled,
            ws_server_port,
            ws_server_auth_key,
            ws_server_tls_enabled,
            ws_server_tls_cert,
            ws_server_tls_key,
        )
    }

    /// Verify that is_dirty() agrees with the derived PartialEq.
    ///
    /// Identical settings must not be dirty, and settings that differ in a
    /// field must be dirty.
    #[test]
    fn test_is_dirty_uses_derived_eq() {
        let s1 = ModelSettings::default();
        let s2 = ModelSettings::default();
        let s3 = s1.clone();

        // Identical settings should not be dirty
        assert!(!s1.is_dirty(&s2));
        assert!(!s1.is_dirty(&s3));

        // Derived PartialEq must match is_dirty
        assert_eq!(s1 != s2, s1.is_dirty(&s2));
        assert_eq!(s1 != s3, s1.is_dirty(&s3));

        // Settings that differ in one field should be dirty.
        // `wrapping_add` keeps this panic-free whatever the default value is.
        let mut changed = s1.clone();
        changed.context_length = changed.context_length.wrapping_add(1);
        assert_ne!(s1, changed);
        assert!(s1.is_dirty(&changed));
        assert_eq!(s1 != changed, s1.is_dirty(&changed));
    }

    /// Spot-check the From<DefaultParams> for ModelSettings implementation.
    ///
    /// Only a few representative fields are compared; this does not prove
    /// that every field is mapped.
    #[test]
    fn test_from_default_params_completeness() {
        let dp = crate::config::DefaultParams::default();
        // Clone so `dp` stays available for the comparisons below.
        let ms: ModelSettings = dp.clone().into();

        // Values expected after converting the default parameters.
        assert_eq!(ms.context_length, 131072);
        assert_eq!(ms.threads, dp.threads);
        assert_eq!(ms.temperature, 0.8);
        // backend is hardware-dependent in DefaultParams::default(),
        // so we just verify it was mapped correctly
        assert_eq!(ms.backend, dp.backend);
    }
}