Skip to main content

llm_manager/
models.rs

1use std::path::PathBuf;
2use std::time::Duration;
3
4use serde::{Deserialize, Serialize};
5
/// Lifecycle state of a single model as tracked by the manager.
#[derive(Clone, Debug, PartialEq)]
pub enum ModelState {
    /// Not running; can be loaded.
    Available,
    /// A load is in progress.
    Loading,
    /// A benchmark run is in progress.
    Benchmarking,
    /// Running.
    Loaded {
        /// Port of the running server.
        port: u16,
        /// Process id of the running server.
        pid: u32,
    },
    /// Loading or running failed.
    Failed {
        /// Human-readable failure reason.
        error: String,
    },
}
15
/// Ordering applied to search results.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SearchSort {
    Relevance,
    Downloads,
    Likes,
    Trending,
    CreatedAt,
}

impl SearchSort {
    /// Advance to the following sort order, wrapping from `CreatedAt` back to `Relevance`.
    pub fn next(self) -> Self {
        match self {
            Self::Relevance => Self::Downloads,
            Self::Downloads => Self::Likes,
            Self::Likes => Self::Trending,
            Self::Trending => Self::CreatedAt,
            Self::CreatedAt => Self::Relevance,
        }
    }

    /// Short human-readable name for UI display (`CreatedAt` is shown as "Created").
    pub fn label(self) -> &'static str {
        match self {
            Self::Relevance => "Relevance",
            Self::Downloads => "Downloads",
            Self::Likes => "Likes",
            Self::Trending => "Trending",
            Self::CreatedAt => "Created",
        }
    }
}
47
48/// A model found via HuggingFace search.
49#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
50pub struct SearchResult {
51    pub model_id: String,
52    pub model_name: String,
53    pub tags: Vec<String>,
54    pub downloads: u64,
55    pub likes: u64,
56    pub pipeline_tag: Option<String>,
57    pub size: Option<u64>,
58    pub parameters: Option<String>,
59    pub capabilities: Vec<String>,
60    pub context_length: Option<u32>,
61    pub readme: Option<String>,
62    /// Quantization type extracted from GGUF metadata (e.g. "Q4_K_M", "Q8_0").
63    pub quantization: Option<String>,
64    /// License extracted from tags (e.g. "apache-2.0", "llama3.1").
65    pub license: Option<String>,
66    /// HuggingFace trending score.
67    pub trending_score: i64,
68    /// Creation timestamp string.
69    pub created_at: Option<String>,
70    /// Whether a matching GGUF file is already downloaded locally.
71    #[serde(default)]
72    pub downloaded: bool,
73}
74
75/// Download progress information.
76#[derive(Debug, Clone)]
77pub struct DownloadState {
78    pub model_id: String,
79    pub filename: String,
80    pub total_bytes: u64,
81    pub downloaded_bytes: u64,
82    pub status: DownloadStatus,
83    pub cancelled: bool,
84    pub cancel_token: Option<std::sync::Arc<std::sync::atomic::AtomicBool>>,
85    /// Download control: 1=downloading, 2=paused, 3=cancelled
86    pub download_state: u8,
87    /// Shared atomic state for pausing/resuming the download loop
88    pub download_state_arc: Option<std::sync::Arc<std::sync::atomic::AtomicU8>>,
89    pub start_time: std::time::Instant,
90    pub bytes_per_second: f64,
91    /// Filesystem path where the download is being saved.
92    pub dest: Option<std::path::PathBuf>,
93}
94
95impl DownloadState {
96    pub fn new(model_id: String, filename: String, total_bytes: u64) -> Self {
97        Self {
98            model_id,
99            filename,
100            total_bytes,
101            downloaded_bytes: 0,
102            status: DownloadStatus::Downloading,
103            cancelled: false,
104            cancel_token: None,
105            download_state: 1,
106            download_state_arc: None,
107            start_time: std::time::Instant::now(),
108            bytes_per_second: 0.0,
109            dest: None,
110        }
111    }
112}
113
114impl ModelSettings {
115    /// Get the version string for the currently active backend.
116    pub fn get_active_backend_version(&self) -> Option<&String> {
117        match self.backend {
118            Backend::Cpu => self.llama_cpp_version_cpu.as_ref(),
119            Backend::Vulkan => self.llama_cpp_version_vulkan.as_ref(),
120            Backend::Rocm => self.llama_cpp_version_rocm.as_ref(),
121            Backend::RocmLemonade => self.llama_cpp_version_rocm_lemonade.as_ref(),
122            Backend::Cuda => self.llama_cpp_version_cuda.as_ref(),
123            _ => None,
124        }
125    }
126
127    /// Get the display version string for the currently active backend (defaults to "latest").
128    pub fn get_active_backend_version_display(&self) -> &str {
129        self.get_active_backend_version()
130            .map(|s| s.as_str())
131            .unwrap_or("latest")
132    }
133
134    /// Set the version string for the currently active backend.
135    pub fn set_active_backend_version(&mut self, tag: Option<String>) {
136        match self.backend {
137            Backend::Cpu => self.llama_cpp_version_cpu = tag,
138            Backend::Vulkan => self.llama_cpp_version_vulkan = tag,
139            Backend::Rocm => self.llama_cpp_version_rocm = tag,
140            Backend::RocmLemonade => self.llama_cpp_version_rocm_lemonade = tag,
141            Backend::Cuda => self.llama_cpp_version_cuda = tag,
142            _ => {}
143        }
144    }
145}
146
/// Strip a trailing `.gguf` extension from a model name, ignoring ASCII case.
///
/// `"model.gguf"`, `"model.GGUF"` and `"model.Gguf"` all become `"model"`. Only one extension is
/// removed (`"a.gguf.gguf"` -> `"a.gguf"`). Names without the extension are returned unchanged.
pub fn strip_gguf(name: &str) -> &str {
    const EXT: &str = ".gguf";
    match name.len().checked_sub(EXT.len()) {
        // `str::get` returns None when `cut` is not a char boundary, so the slice below
        // can never panic on multi-byte names.
        Some(cut)
            if name
                .get(cut..)
                .is_some_and(|tail| tail.eq_ignore_ascii_case(EXT)) =>
        {
            &name[..cut]
        }
        _ => name,
    }
}
153
/// Normalise a host string for URL construction and CLI arguments.
///
/// - Blank input (empty or whitespace only) becomes `"127.0.0.1"`.
/// - Only the first whitespace-separated token is kept, which drops display suffixes such as
///   `"localhost (127.0.0.1)"` -> `"localhost"`.
/// - A token containing `:` that is not already bracketed (an IPv6 literal) is wrapped in `[...]`.
pub fn clean_host(host: &str) -> String {
    let Some(bare) = host.split_whitespace().next() else {
        return "127.0.0.1".to_string();
    };
    if bare.starts_with('[') || !bare.contains(':') {
        bare.to_string()
    } else {
        format!("[{}]", bare)
    }
}
170
/// Human-friendly form of a host string for display.
///
/// An empty host and `"127.0.0.1"` are both shown as `"localhost (127.0.0.1)"`; anything else
/// is returned verbatim (no trimming).
pub fn format_host(host: &str) -> &str {
    if host.is_empty() || host == "127.0.0.1" {
        return "localhost (127.0.0.1)";
    }
    host
}
178
/// Converts the global [`DefaultParams`](crate::config::DefaultParams) into per-model settings.
///
/// Every field is copied one-to-one, with one exception: GPU offload. A negative legacy
/// `gpu_layers` value maps to [`GpuLayersMode::All`]; otherwise `gpu_layers_mode` is used as-is.
// NOTE: keep this as an explicit one-line-per-field list. The "WHEN ADDING A NEW PARAMETER"
// checklist before the ModelSettings struct names this mapping, and that comment says a build
// script checks field counts.
impl From<crate::config::DefaultParams> for ModelSettings {
    fn from(dp: crate::config::DefaultParams) -> Self {
        Self {
            // ── Loading
            context_length: dp.context_length,
            threads: dp.threads,
            threads_batch: dp.threads_batch,
            batch_size: dp.batch_size,
            ubatch_size: dp.ubatch_size,
            parallel: dp.parallel,
            max_concurrent_predictions: dp.max_concurrent_predictions,
            uniform_cache: dp.uniform_cache,
            kv_cache_offload: dp.kv_cache_offload,
            cache_type_k: dp.cache_type_k,
            cache_type_v: dp.cache_type_v,
            keep: dp.keep,
            swa_full: dp.swa_full,
            mlock: dp.mlock,
            mmap: dp.mmap,
            numa: dp.numa,
            system_prompt: dp.system_prompt,
            system_prompt_preset_name: dp.system_prompt_preset_name,

            // ── GPU
            // A negative legacy `gpu_layers` count forces full offload; otherwise the
            // explicit mode setting wins.
            gpu_layers_mode: match dp.gpu_layers {
                n if n < 0 => GpuLayersMode::All,
                _ => dp.gpu_layers_mode,
            },
            split_mode: dp.split_mode,
            tensor_split: dp.tensor_split,
            main_gpu: dp.main_gpu,
            fit: dp.fit,
            lora: dp.lora,
            lora_scaled: dp.lora_scaled,
            rpc: dp.rpc,
            embedding: dp.embedding,
            flash_attn: dp.flash_attn,
            expert_count: dp.expert_count,
            jinja: dp.jinja,
            chat_template: dp.chat_template,
            chat_template_kwargs: dp.chat_template_kwargs,
            // ── Sampling
            seed: dp.seed,
            temperature: dp.temperature,
            top_k: dp.top_k,
            top_p: dp.top_p,
            min_p: dp.min_p,
            typical_p: dp.typical_p,
            mirostat: dp.mirostat,
            mirostat_lr: dp.mirostat_lr,
            mirostat_ent: dp.mirostat_ent,
            ignore_eos: dp.ignore_eos,
            samplers: dp.samplers,
            // ── Repetition control
            repeat_penalty: dp.repeat_penalty,
            repeat_last_n: dp.repeat_last_n,
            presence_penalty: dp.presence_penalty,
            frequency_penalty: dp.frequency_penalty,
            dry_multiplier: dp.dry_multiplier,
            dry_base: dp.dry_base,
            dry_allowed_length: dp.dry_allowed_length,
            dry_penalty_last_n: dp.dry_penalty_last_n,
            // ── RoPE
            rope_scaling: dp.rope_scaling,
            rope_scale: dp.rope_scale,
            rope_freq_base: dp.rope_freq_base,
            rope_freq_scale: dp.rope_freq_scale,
            rope_yarn_enabled: dp.rope_yarn_enabled,
            // ── Server
            host: dp.host,
            port: dp.port,
            timeout: dp.timeout,
            cache_prompt: dp.cache_prompt,
            cache_reuse: dp.cache_reuse,
            webui: dp.webui,
            // ── Other
            max_tokens: dp.max_tokens,
            cache_type: dp.cache_type,
            backend: dp.backend,
            llama_cpp_version_cpu: dp.llama_cpp_version_cpu,
            llama_cpp_version_vulkan: dp.llama_cpp_version_vulkan,
            llama_cpp_version_rocm: dp.llama_cpp_version_rocm,
            llama_cpp_version_rocm_lemonade: dp.llama_cpp_version_rocm_lemonade,
            llama_cpp_version_cuda: dp.llama_cpp_version_cuda,
            api_endpoint_enabled: dp.api_endpoint_enabled,
            api_endpoint_port: dp.api_endpoint_port,

            spec_type: dp.spec_type,
            draft_tokens: dp.draft_tokens,
            tags: dp.tags,
        }
    }
}
265
266/// How to handle GPU layer offloading.
267#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize, Hash)]
268#[derive(Default)]
269pub enum GpuLayersMode {
270    #[default]
271    Auto,
272    Specific(u32),
273    All,
274}
275
276
/// Lifecycle status of a single download.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum DownloadStatus {
    /// Data is being transferred.
    Downloading,
    /// Transfer is suspended.
    Paused,
    /// The file finished downloading.
    Complete,
    /// The download failed; carries the error message.
    Error(String),
    /// The download was cancelled.
    Cancelled,
}
285
286// ── Cache type enums ──────────────────────────────────────────
287
288/// Main KV cache data type.
289#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Hash, Default)]
290pub enum CacheType {
291    #[serde(rename = "f16")]
292    #[default]
293    F16,
294    #[serde(rename = "bf16")]
295    BF16,
296    #[serde(rename = "fq8_0")]
297    Fq8_0,
298    #[serde(rename = "fq4_1")]
299    Fq4_1,
300}
301
302impl std::fmt::Display for CacheType {
303    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
304        match self {
305            CacheType::F16 => write!(f, "f16"),
306            CacheType::BF16 => write!(f, "bf16"),
307            CacheType::Fq8_0 => write!(f, "fq8_0"),
308            CacheType::Fq4_1 => write!(f, "fq4_1"),
309        }
310    }
311}
312
313/// KV cache quantization type.
314#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default, Hash)]
315pub enum CacheQuantType {
316    #[serde(rename = "f32")]
317    F32,
318    #[serde(rename = "f16")]
319    #[default]
320    F16,
321    #[serde(rename = "bf16")]
322    BF16,
323    #[serde(rename = "q8_0")]
324    Q8_0,
325    #[serde(rename = "q4_0")]
326    Q4_0,
327    #[serde(rename = "q4_1")]
328    Q4_1,
329    #[serde(rename = "iq4_nl")]
330    Iq4Nl,
331    #[serde(rename = "q5_0")]
332    Q5_0,
333    #[serde(rename = "q5_1")]
334    Q5_1,
335}
336
337pub type CacheTypeK = CacheQuantType;
338pub type CacheTypeV = CacheQuantType;
339
340impl CacheQuantType {
341    pub fn from_u8(n: u8) -> Self {
342        match n {
343            0 => Self::F32,
344            1 => Self::F16,
345            2 => Self::BF16,
346            3 => Self::Q8_0,
347            4 => Self::Q5_1,
348            5 => Self::Q5_0,
349            6 => Self::Q4_1,
350            7 => Self::Q4_0,
351            8 => Self::Iq4Nl,
352            _ => Self::F16,
353        }
354    }
355    pub fn next(&self) -> Self {
356        match self {
357            Self::F32 => Self::F16,
358            Self::F16 => Self::BF16,
359            Self::BF16 => Self::Q8_0,
360            Self::Q8_0 => Self::Q5_1,
361            Self::Q5_1 => Self::Q5_0,
362            Self::Q5_0 => Self::Q4_1,
363            Self::Q4_1 => Self::Q4_0,
364            Self::Q4_0 => Self::Iq4Nl,
365            Self::Iq4Nl => Self::F32,
366        }
367    }
368    pub fn prev(&self) -> Self {
369        match self {
370            Self::F32 => Self::Iq4Nl,
371            Self::F16 => Self::F32,
372            Self::BF16 => Self::F16,
373            Self::Q8_0 => Self::BF16,
374            Self::Q5_1 => Self::Q8_0,
375            Self::Q5_0 => Self::Q5_1,
376            Self::Q4_1 => Self::Q5_0,
377            Self::Q4_0 => Self::Q4_1,
378            Self::Iq4Nl => Self::Q4_0,
379        }
380    }
381}
382
383impl std::fmt::Display for CacheQuantType {
384    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
385        match self {
386            Self::F32 => write!(f, "f32"),
387            Self::F16 => write!(f, "f16"),
388            Self::BF16 => write!(f, "bf16"),
389            Self::Q8_0 => write!(f, "q8_0"),
390            Self::Q4_0 => write!(f, "q4_0"),
391            Self::Q4_1 => write!(f, "q4_1"),
392            Self::Iq4Nl => write!(f, "iq4_nl"),
393            Self::Q5_0 => write!(f, "q5_0"),
394            Self::Q5_1 => write!(f, "q5_1"),
395        }
396    }
397}
398
399impl From<&str> for CacheQuantType {
400    fn from(s: &str) -> Self {
401        match s {
402            "F32" => Self::F32,
403            "F16" => Self::F16,
404            "BF16" => Self::BF16,
405            "Q8_0" => Self::Q8_0,
406            "Q4_0" => Self::Q4_0,
407            "Q4_1" => Self::Q4_1,
408            "Iq4Nl" => Self::Iq4Nl,
409            "Q5_0" => Self::Q5_0,
410            "Q5_1" => Self::Q5_1,
411            _ => Self::F16, // Default or error handling
412        }
413    }
414}
415
416/// Split mode for multi-GPU.
417#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Hash, Default)]
418pub enum SplitMode {
419    #[serde(rename = "none")]
420    None,
421    #[serde(rename = "layer")]
422    #[default]
423    Layer,
424    #[serde(rename = "row")]
425    Row,
426    #[serde(rename = "tensor")]
427    Tensor,
428}
429
430impl std::fmt::Display for SplitMode {
431    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
432        match self {
433            SplitMode::None => write!(f, "none"),
434            SplitMode::Layer => write!(f, "layer"),
435            SplitMode::Row => write!(f, "row"),
436            SplitMode::Tensor => write!(f, "tensor"),
437        }
438    }
439}
440
441/// NUMA optimization mode.
442#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Hash, Default)]
443pub enum NumMode {
444    #[serde(rename = "none")]
445    #[default]
446    None,
447    #[serde(rename = "distribute")]
448    Distribute,
449    #[serde(rename = "isolate")]
450    Isolate,
451    #[serde(rename = "numactl")]
452    Numactl,
453}
454
455impl std::fmt::Display for NumMode {
456    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
457        match self {
458            NumMode::None => write!(f, "none"),
459            NumMode::Distribute => write!(f, "distribute"),
460            NumMode::Isolate => write!(f, "isolate"),
461            NumMode::Numactl => write!(f, "numactl"),
462        }
463    }
464}
465
466/// RoPE frequency scaling method.
467#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Hash, Default)]
468pub enum RopeScaling {
469    #[serde(rename = "none")]
470    #[default]
471    None,
472    #[serde(rename = "linear")]
473    Linear,
474    #[serde(rename = "yarn")]
475    Yarn,
476}
477
478impl std::fmt::Display for RopeScaling {
479    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
480        match self {
481            RopeScaling::None => write!(f, "none"),
482            RopeScaling::Linear => write!(f, "linear"),
483            RopeScaling::Yarn => write!(f, "yarn"),
484        }
485    }
486}
487
488/// Mirostat version.
489#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Hash, Default)]
490pub enum Mirostat {
491    #[serde(rename = "0")]
492    #[default]
493    Off,
494    #[serde(rename = "1")]
495    V1,
496    #[serde(rename = "2")]
497    Mirostat2,
498}
499
500impl std::fmt::Display for Mirostat {
501    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
502        match self {
503            Mirostat::Off => write!(f, "off"),
504            Mirostat::V1 => write!(f, "1"),
505            Mirostat::Mirostat2 => write!(f, "2"),
506        }
507    }
508}
509
510/// Sampler order string (semicolon-separated).
511/// Common types: penalties, dry, top_n_sigma, top_k, typ_p, top_p, min_p, xtc, temperature
512#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
513pub struct Samplers(pub String);
514
515impl Default for Samplers {
516    fn default() -> Self {
517        Self("penalties;dry;top_n_sigma;top_k;typ_p;top_p;min_p;xtc;temperature".to_string())
518    }
519}
520
521impl std::fmt::Display for Samplers {
522    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
523        write!(f, "{}", self.0)
524    }
525}
526
527/// Backend used to run the llama.cpp server.
528#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
529pub enum Backend {
530    #[serde(rename = "cpu")]
531    #[default]
532    Cpu,
533    #[serde(rename = "vulkan")]
534    Vulkan,
535    #[serde(rename = "rocm")]
536    Rocm,
537    #[serde(rename = "rocm_lemonade")]
538    RocmLemonade,
539    #[serde(rename = "cuda")]
540    Cuda,
541    #[serde(rename = "cpu_arm64")]
542    CpuArm64,
543    #[serde(rename = "win_cpu")]
544    CpuWindows,
545    #[serde(rename = "win_vulkan")]
546    VulkanWindows,
547    #[serde(rename = "win_cuda_12_4")]
548    CudaWindows12_4,
549    #[serde(rename = "win_cuda_13_1")]
550    CudaWindows13_1,
551    #[serde(rename = "win_hip")]
552    HipWindows,
553    #[serde(rename = "macos_arm64")]
554    CpuMacosArm64,
555    #[serde(rename = "macos_x64")]
556    CpuMacosX64,
557}
558
559impl Backend {
560    /// Get the identifier used for directory names and asset prefixes.
561    pub fn slug(&self) -> &'static str {
562        match self {
563            Backend::Cpu => "cpu",
564            Backend::Vulkan => "vulkan",
565            Backend::Rocm => "rocm",
566            Backend::RocmLemonade => "rocm-lemonade",
567            Backend::Cuda => "cuda",
568            Backend::CpuArm64 => "cpu-arm64",
569            Backend::CpuWindows => "win-cpu",
570            Backend::VulkanWindows => "win-vulkan",
571            Backend::CudaWindows12_4 => "win-cuda-12.4",
572            Backend::CudaWindows13_1 => "win-cuda-13.1",
573            Backend::HipWindows => "win-hip",
574            Backend::CpuMacosArm64 => "macos-arm64",
575            Backend::CpuMacosX64 => "macos-x64",
576        }
577    }
578
579    /// Returns true if this backend is for Linux.
580    pub fn is_linux(self) -> bool {
581        matches!(
582            self,
583            Backend::Cpu
584                | Backend::Vulkan
585                | Backend::Rocm
586                | Backend::RocmLemonade
587                | Backend::Cuda
588                | Backend::CpuArm64
589        )
590    }
591
592    /// Returns true if this backend is for Windows.
593    pub fn is_windows(self) -> bool {
594        matches!(
595            self,
596            Backend::CpuWindows
597                | Backend::VulkanWindows
598                | Backend::CudaWindows12_4
599                | Backend::CudaWindows13_1
600                | Backend::HipWindows
601        )
602    }
603
604    /// Returns true if this backend is for macOS.
605    pub fn is_macos(self) -> bool {
606        matches!(self, Backend::CpuMacosArm64 | Backend::CpuMacosX64)
607    }
608
609    /// Parse backend from string representation.
610    pub fn from_str(s: &str) -> Self {
611        let s = s.to_lowercase();
612        if s.starts_with("vulkan") || s.starts_with("vk") {
613            Backend::Vulkan
614        } else if s.starts_with("rocm") || s.starts_with("ro") {
615            if s.contains("lemonade") {
616                Backend::RocmLemonade
617            } else {
618                Backend::Rocm
619            }
620        } else if s.starts_with("cuda") || s.starts_with("cu") {
621            Backend::Cuda
622        } else {
623            Backend::Cpu // Default
624        }
625    }
626}
627
628impl std::fmt::Display for Backend {
629    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
630        write!(f, "{}", self.slug())
631    }
632}
633
634/// Server mode: normal (single model) or router (multiple models).
635#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
636pub enum ServerMode {
637    #[serde(rename = "normal")]
638    #[default]
639    Normal,
640    #[serde(rename = "router")]
641    Router,
642    #[serde(rename = "bench_gpu", alias = "bench")]
643    Bench,
644    #[serde(rename = "bench_tune")]
645    BenchTune,
646}
647
648impl std::fmt::Display for ServerMode {
649    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
650        match self {
651            ServerMode::Normal => write!(f, "Normal"),
652            ServerMode::Router => write!(f, "Router (XP!)"),
653            ServerMode::Bench => write!(f, "Bench GPU"),
654            ServerMode::BenchTune => write!(f, "BenchTune"),
655        }
656    }
657}
658
659// ── ModelSettings ─────────────────────────────────────────────
660//
661// WHEN ADDING A NEW PARAMETER to ModelSettings, update ALL of these locations:
//   1.  src/models.rs           — ModelSettings struct field + doc comment (struct follows this comment)
663//   2.  src/models.rs:179       — From<DefaultParams> for ModelSettings (field mapping)
664//   3.  src/config.rs:564       — DefaultParams struct field + serde attribute
665//   4.  src/config.rs:785       — DefaultParams Default impl (default value)
666//   5.  src/config.rs:175       — ModelOverride struct field (Option<T>)
667//   6.  src/config.rs:299       — ModelOverride::from_settings() (Some(s.field))
668//   7.  src/config.rs:385       — ModelOverride::apply() (one of the 3 macro calls)
669//   8.  src/tui/settings.rs:312 — all_fields() SettingField entry (id, name, section, etc.)
//   9.  src/tui/settings.rs:1149 — profile_settings_parts() diff macro call
671//  10.  src/tui/app/profiles.rs:80 — settings_fingerprint() hash call
672//  11.  src/tui/event/helpers.rs:125 — sync_global_settings() (if global-scoped)
673//  12.  src/tui/event/panel/settings.rs — key handlers for numeric/toggle edit
674//
675// The derived PartialEq on ModelSettings and DefaultParams guarantees
676// is_dirty() compares ALL fields — no field can be missed.
677// The build.rs script checks field counts at compile time.
678
/// Settings for loading a model via llama.cpp server.
///
/// Built from the global [`DefaultParams`](crate::config::DefaultParams) through `From`.
/// `PartialEq` is derived, so every field takes part in equality checks.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ModelSettings {
    // ── Loading ──────────────────────────────────────────────
    /// Size of the prompt context.
    pub context_length: u32,
    /// Number of CPU threads for generation.
    pub threads: u32,
    /// Number of CPU threads for batch processing.
    pub threads_batch: u32,
    /// Logical maximum batch size.
    pub batch_size: u32,
    /// Physical maximum batch size (micro-batch).
    pub ubatch_size: u32,
    /// Max concurrent predictions (sequences).
    /// NOTE(review): overlaps with `max_concurrent_predictions`, whose doc says it controls
    /// `--parallel`; confirm which of the two is actually emitted.
    pub parallel: u32,
    /// Max concurrent predictions (requests in flight). None means no --parallel argument.
    pub max_concurrent_predictions: Option<u32>,
    /// Use uniform (unified) KV cache across all sequences.
    pub uniform_cache: bool,
    /// Offload KV cache to system RAM.
    pub kv_cache_offload: bool,
    /// KV cache data type for K.
    pub cache_type_k: Option<CacheTypeK>,
    /// KV cache data type for V.
    pub cache_type_v: Option<CacheTypeV>,
    /// Keep N tokens from the initial prompt.
    pub keep: i32,
    /// Use full-size SWA cache.
    pub swa_full: bool,
    /// Force system to keep model in RAM.
    pub mlock: bool,
    /// Memory-map the model.
    pub mmap: bool,
    /// NUMA optimization.
    pub numa: NumMode,
    /// System prompt.
    pub system_prompt: String,
    /// Name of the system prompt preset currently selected.
    pub system_prompt_preset_name: String,

    // ── GPU ──────────────────────────────────────────────────
    /// GPU layer offloading mode.
    pub gpu_layers_mode: GpuLayersMode,
    /// Split mode across multiple GPUs.
    pub split_mode: SplitMode,
    /// Fraction of model offloaded to each GPU (comma-separated).
    pub tensor_split: String,
    /// Main GPU index.
    pub main_gpu: i32,
    /// Whether to adjust arguments to fit device memory.
    pub fit: bool,
    /// Path to LoRA adapter.
    pub lora: Option<PathBuf>,
    /// LoRA adapter path paired with its scale factor.
    pub lora_scaled: Option<(PathBuf, f32)>,
    /// RPC servers.
    pub rpc: String,
    /// Restrict to embedding use case.
    pub embedding: bool,
    /// Enable Flash Attention.
    pub flash_attn: bool,
    /// Active experts per token (MoE models, -1 = model default).
    pub expert_count: i32,
    /// Use Jinja template engine for chat.
    pub jinja: bool,
    /// Custom chat template string.
    pub chat_template: Option<String>,
    /// JSON string for --chat-template-kwargs (e.g. {"enable_thinking": false}).
    pub chat_template_kwargs: Option<String>,

    // ── Sampling ─────────────────────────────────────────────
    /// RNG seed (-1 = random).
    pub seed: i32,
    /// Temperature.
    pub temperature: f32,
    /// Top-k sampling (0 = disabled).
    pub top_k: i32,
    /// Top-p sampling (1.0 = disabled).
    pub top_p: f32,
    /// Minimum probability for a token.
    pub min_p: f32,
    /// Locally typical sampling parameter p.
    pub typical_p: f32,
    /// Mirostat version (0=off, 1=Mirostat, 2=Mirostat2).
    pub mirostat: Mirostat,
    /// Mirostat learning rate (eta).
    pub mirostat_lr: f32,
    /// Mirostat target entropy (tau).
    pub mirostat_ent: f32,
    /// Ignore end-of-stream token.
    pub ignore_eos: bool,
    /// Sampler order string.
    pub samplers: Samplers,

    // ── Repetition Control ───────────────────────────────────
    /// Penalize repeat sequence of tokens.
    pub repeat_penalty: f32,
    /// Last N tokens to consider for repeat penalty.
    pub repeat_last_n: i32,
    /// Repeat alpha presence penalty.
    pub presence_penalty: Option<f32>,
    /// Repeat alpha frequency penalty.
    pub frequency_penalty: Option<f32>,
    /// DRY sampling multiplier.
    pub dry_multiplier: f32,
    /// DRY sampling base value.
    pub dry_base: f32,
    /// DRY allowed length.
    pub dry_allowed_length: i32,
    /// DRY penalty last N.
    pub dry_penalty_last_n: i32,

    // ── RoPE ─────────────────────────────────────────────────
    /// RoPE frequency scaling method.
    pub rope_scaling: RopeScaling,
    /// RoPE context scaling factor.
    pub rope_scale: f32,
    /// RoPE base frequency.
    pub rope_freq_base: f32,
    /// RoPE frequency scaling factor.
    pub rope_freq_scale: f32,
    /// Enable Yarn RoPE scaling mode.
    pub rope_yarn_enabled: bool,

    // ── Server ───────────────────────────────────────────────
    /// Host address.
    pub host: String,
    /// Port.
    pub port: u16,
    /// Server timeout in seconds.
    pub timeout: u32,
    /// Whether to enable prompt caching.
    pub cache_prompt: bool,
    /// Min chunk size for cache reuse.
    pub cache_reuse: u32,
    /// Whether to enable WebUI.
    pub webui: bool,

    // ── Other ────────────────────────────────────────────────
    /// Max tokens to predict.
    pub max_tokens: Option<u32>,
    /// Cache type (legacy, kept for compatibility).
    pub cache_type: CacheType,
    /// llama.cpp build to run (any [`Backend`] variant: CPU, Vulkan, ROCm, CUDA,
    /// Windows and macOS builds).
    pub backend: Backend,
    /// llama.cpp release tag for CPU backend (e.g. "b1234" or None for latest).
    pub llama_cpp_version_cpu: Option<String>,
    /// llama.cpp release tag for Vulkan backend (e.g. "b1234" or None for latest).
    pub llama_cpp_version_vulkan: Option<String>,
    /// llama.cpp release tag for ROCm backend (e.g. "b1234" or None for latest).
    pub llama_cpp_version_rocm: Option<String>,
    /// Lemonade llama.cpp release tag for ROCm backend.
    pub llama_cpp_version_rocm_lemonade: Option<String>,
    /// llama.cpp release tag for CUDA backend.
    pub llama_cpp_version_cuda: Option<String>,
    /// Whether to enable the API proxy server.
    pub api_endpoint_enabled: bool,
    /// Port for the API proxy server.
    pub api_endpoint_port: u16,
    /// Speculative decoding type (e.g., "draft-mtp", "ngram-simple", "" for off).
    pub spec_type: String,
    /// Number of draft tokens for MTP.
    pub draft_tokens: u32,
    /// Tags for the model.
    pub tags: Vec<String>,
}
846
847impl Default for ModelSettings {
848    fn default() -> Self {
849        let mut s: Self = crate::config::DefaultParams::default().into();
850        // Override fields that differ from DefaultParams defaults
851        s.uniform_cache = false;
852        s.cache_type_k = Some(CacheTypeK::F16);
853        s.cache_type_v = Some(CacheTypeV::F16);
854        s.cache_type = CacheType::default();
855        s.backend = Backend::Cpu;
856        s.presence_penalty = Some(0.0);
857        s.frequency_penalty = Some(0.0);
858        s
859    }
860}
861
862impl ModelSettings {
863    /// Create ModelSettings from config defaults, applying model-specific overrides.
864    pub fn from_config(config: &crate::config::Config) -> Self {
865        config.default.clone().into()
866    }
867}
868
/// Prompt sent by default when a benchmark/tuning session starts.
///
/// It asks for a long, open-ended generation so the run produces plenty of tokens to time.
pub const BENCHMARK_PROMPT: &str = "Create Mona Lisa image in ascii art using text, number, symbol, everything possible. this should be the perfect painting.";
871
/// A model file found on disk.
#[derive(Clone, Debug)]
pub struct DiscoveredModel {
    /// Location of the file.
    pub path: PathBuf,
    /// Model name.
    pub name: String,
    /// Size of the file (presumably bytes).
    pub file_size: u64,
    /// Path relative to the model directory, used for display.
    pub display_name: String,
}
880
/// GGUF metadata parsed from a model file, cached so the file need not be re-parsed.
///
/// `Default` yields zeroed counts and empty strings/lists, meaning "not known".
#[derive(Clone, Debug, Default)]
pub struct GgufMetadata {
    /// Number of blocks (`<arch>.block_count`).
    pub layers: u32,
    /// Embedding width (`<arch>.embedding_length`).
    pub hidden_size: u32,
    /// Training context length (`<arch>.context_length`).
    pub n_ctx_train: u32,
    /// Attention head count (`<arch>.attention.head_count`).
    pub n_head: u32,
    /// KV head count (`<arch>.attention.head_count_kv`).
    pub n_kv_head: u32,
    /// Model architecture (`general.architecture`).
    pub arch: String,
    /// File type description.
    pub file_type: String,
    /// Quantization label.
    pub quantization: String,
    /// Parameter count as a display string.
    pub model_parameters: String,
    /// Model domain.
    pub domain: String,
    /// Capability labels.
    pub capabilities: Vec<String>,
    /// Tokenizer name.
    pub tokenizer: String,
    /// Vocabulary size (presumably from the tokenizer token list — confirm).
    pub vocab_size: u32,
    /// Draft token count.
    pub draft_tokens: u32,
}
899
900impl GgufMetadata {
901    pub fn from_path(path: &std::path::Path) -> anyhow::Result<Self> {
902        let path_str = path.to_string_lossy();
903        let mut container = gguf_rs::get_gguf_container(&path_str)
904            .map_err(|e| anyhow::anyhow!("Failed to get GGUF container: {}", e))?;
905        let model_data = container
906            .decode()
907            .map_err(|e| anyhow::anyhow!("Failed to decode GGUF: {}", e))?;
908
909        let mut meta = Self::default();
910
911        let extract_str = |key: &str| -> String {
912            model_data
913                .metadata()
914                .get(key)
915                .and_then(|v| v.as_str().map(|s| s.to_string()))
916                .unwrap_or_default()
917        };
918
919        let extract_num = |key: &str| -> Option<u64> {
920            model_data.metadata().get(key).and_then(|v| {
921                v.as_u64()
922                    .or_else(|| v.as_i64().map(|x| x as u64))
923                    .or_else(|| v.as_f64().map(|x| x as u64))
924            })
925        };
926
927        meta.arch = extract_str("general.architecture");
928        let prefix = if meta.arch.is_empty() {
929            "llama"
930        } else {
931            &meta.arch
932        };
933
934        let get_num_with_fallback = |suffix: &str| -> u32 {
935            extract_num(&format!("{}.{}", prefix, suffix))
936                .or_else(|| {
937                    if prefix != "llama" {
938                        extract_num(&format!("llama.{}", suffix))
939                    } else {
940                        None
941                    }
942                })
943                .unwrap_or(0) as u32
944        };
945
946        meta.layers = get_num_with_fallback("block_count");
947        meta.hidden_size = get_num_with_fallback("embedding_length");
948        meta.n_ctx_train = get_num_with_fallback("context_length");
949        meta.n_head = get_num_with_fallback("attention.head_count");
950        meta.n_kv_head = get_num_with_fallback("attention.head_count_kv");
951
952        if let Some(value) = model_data.metadata().get("tokenizer.ggml.tokens")
953            && let Some(arr) = value.as_array()
954        {
955            meta.vocab_size = arr.len() as u32;
956        }
957
958        if meta.arch == "mtp" {
959            meta.draft_tokens = extract_num("mtp.draft_tokens").unwrap_or(0) as u32;
960        }
961
962        if let Some(v) = extract_num("general.file_type") {
963            meta.file_type = match v {
964                0 => "F32".to_string(),
965                1 => "F16".to_string(),
966                2 => "Q4_0".to_string(),
967                3 => "Q4_1".to_string(),
968                7 => "Q8_0".to_string(),
969                8 => "Q5_0".to_string(),
970                9 => "Q5_1".to_string(),
971                10 => "Q2_K".to_string(),
972                11 => "Q3_K_S".to_string(),
973                12 => "Q3_K_M".to_string(),
974                13 => "Q3_K_L".to_string(),
975                14 => "Q4_K_S".to_string(),
976                15 => "Q4_K_M".to_string(),
977                16 => "Q5_K_S".to_string(),
978                17 => "Q5_K_M".to_string(),
979                18 => "Q6_K".to_string(),
980                19 => "IQ2_XXS".to_string(),
981                20 => "IQ2_XS".to_string(),
982                21 => "IQ3_XXS".to_string(),
983                22 => "IQ1_S".to_string(),
984                23 => "IQ4_NL".to_string(),
985                24 => "IQ3_S".to_string(),
986                25 => "IQ2_S".to_string(),
987                26 => "IQ4_XS".to_string(),
988                _ => format!("Unknown ({})", v),
989            };
990        }
991
992        if let Some(value) = model_data.metadata().get("general.capabilities")
993            && let Some(arr) = value.as_array()
994        {
995            for v in arr {
996                if let Some(s) = v.as_str() {
997                    meta.capabilities.push(s.to_string());
998                }
999            }
1000        }
1001
1002        if model_data
1003            .metadata()
1004            .contains_key("tokenizer.chat_template")
1005        {
1006            meta.capabilities.push("chat".to_string());
1007        }
1008
1009        meta.tokenizer = extract_str("tokenizer.ggml.model");
1010        meta.domain = extract_str("general.domain");
1011        meta.model_parameters = model_data.model_parameters();
1012
1013        Ok(meta)
1014    }
1015}
1016
/// Metrics reported by the llama.cpp server.
///
/// `Default` yields a snapshot with `loaded == false` and every numeric field zero.
#[derive(Debug, Clone, Default)]
pub struct ServerMetrics {
    pub loaded: bool,
    pub tps: f64,
    pub prompt_tps: f64,
    pub cpu_usage: f64,
    pub gpu_mem_used: u64,
    pub gpu_mem_total: u64,
    pub ram_used: u64,
    pub ctx_used: u32,
    pub ctx_max: u32,
    /// Sum of gpu_mem_used across all loaded models (for Total VRAM display).
    pub total_vram_used: u64,
    /// Number of decoded tokens from print_timing logs.
    pub decoded_tokens: u64,
    /// Generation tokens per second parsed from llama.cpp log output (e.g., "tg = 64.45 t/s").
    pub gen_tps: f64,
    /// Estimated latency per generated token in milliseconds.
    pub latency_per_token_ms: f64,
    /// Estimated prompt processing latency in milliseconds (1000 / prompt_tps).
    pub prompt_latency_ms: f64,
}

/// GPU device buffer reported by llama-server during model loading.
#[derive(Debug, Clone)]
pub struct GPUBuffer {
    pub device: String,
    pub buffer_size_mib: f64,
}

/// Progress information during model loading, parsed from llama-server log output.
#[derive(Debug, Clone, Default)]
pub struct LoadProgress {
    /// Total number of layers in the model.
    pub layers_total: Option<u32>,
    /// Number of layers already offloaded to GPU.
    pub layers_loaded: Option<u32>,
    /// Total number of tensors in the model (from "Loading tensor X of Y" log).
    pub tensors_total: Option<u32>,
    /// Number of tensors loaded (counted from dot-lines in log).
    pub tensors_loaded: u32,
    /// GPU device buffers with their sizes.
    pub buffers: Vec<GPUBuffer>,
}
1083
1084/// WebSocket-friendly metrics snapshot (serializable, no internal state).
1085#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
1086pub struct WsMetrics {
1087    pub model_name: String,
1088    pub loaded: bool,
1089    pub state: String,
1090    pub tps: f64,
1091    pub prompt_tps: f64,
1092    pub ctx_used: u32,
1093    pub ctx_max: u32,
1094    pub cpu_usage: f64,
1095    pub gpu_mem_used: u64,
1096    pub gpu_mem_total: u64,
1097    pub ram_used: u64,
1098    pub latency_per_token_ms: f64,
1099    pub decoded_tokens: u64,
1100    pub gen_tps: f64,
1101    pub timestamp: u64,
1102    // Server command
1103    pub cmd_display: Option<String>,
1104    // LLM settings
1105    pub threads: u32,
1106    pub threads_batch: u32,
1107    pub context_length: u32,
1108    pub ubatch_size: u32,
1109    pub batch_size: u32,
1110    pub temperature: f32,
1111    pub top_k: u32,
1112    pub top_p: f32,
1113    pub min_p: f32,
1114    pub typical_p: f32,
1115    pub seed: i32,
1116    pub repeat_penalty: f32,
1117    pub repeat_last_n: i32,
1118    pub presence_penalty: Option<f32>,
1119    pub frequency_penalty: Option<f32>,
1120    pub mirostat: Option<u32>,
1121    pub mirostat_lr: Option<f32>,
1122    pub mirostat_ent: Option<f32>,
1123    pub max_tokens: Option<u32>,
1124    pub flash_attn: bool,
1125    pub kv_cache_offload: bool,
1126    pub cache_type_k: Option<String>,
1127    pub cache_type_v: Option<String>,
1128    pub uniform_cache: bool,
1129    pub mlock: bool,
1130    pub mmap: bool,
1131    pub embedding: bool,
1132    pub jinja: bool,
1133    pub ignore_eos: bool,
1134    pub samplers: String,
1135    pub expert_count: u32,
1136    pub gpu_layers: String,
1137    pub backend: String,
1138    pub llama_cpp_version: String,
1139    pub spec_type: String,
1140    pub draft_tokens: u32,
1141}
1142
1143impl WsMetrics {
1144    pub fn from_metrics(
1145        metrics: &ServerMetrics,
1146        model_name: &str,
1147        state: &str,
1148        settings: &crate::models::ModelSettings,
1149        cmd_display: Option<&str>,
1150    ) -> Self {
1151        use std::time::{SystemTime, UNIX_EPOCH};
1152        let timestamp = SystemTime::now()
1153            .duration_since(UNIX_EPOCH)
1154            .map(|d| d.as_secs())
1155            .unwrap_or(0);
1156        let gpu_layers = match settings.gpu_layers_mode {
1157            crate::models::GpuLayersMode::Auto => "Auto".to_string(),
1158            crate::models::GpuLayersMode::Specific(n) => n.to_string(),
1159            crate::models::GpuLayersMode::All => "All".to_string(),
1160        };
1161        Self {
1162            model_name: model_name.to_string(),
1163            loaded: metrics.loaded,
1164            state: state.to_string(),
1165            tps: metrics.tps,
1166            prompt_tps: metrics.prompt_tps,
1167            ctx_used: metrics.ctx_used,
1168            ctx_max: metrics.ctx_max,
1169            cpu_usage: metrics.cpu_usage,
1170            gpu_mem_used: metrics.gpu_mem_used,
1171            gpu_mem_total: metrics.gpu_mem_total,
1172            ram_used: metrics.ram_used,
1173            latency_per_token_ms: metrics.latency_per_token_ms,
1174            decoded_tokens: metrics.decoded_tokens,
1175            gen_tps: metrics.gen_tps,
1176            timestamp,
1177            cmd_display: cmd_display.map(String::from),
1178            threads: settings.threads,
1179            threads_batch: settings.threads_batch,
1180            context_length: settings.context_length,
1181            ubatch_size: settings.ubatch_size,
1182            batch_size: settings.batch_size,
1183            temperature: settings.temperature,
1184            top_k: settings.top_k as u32,
1185            top_p: settings.top_p,
1186            min_p: settings.min_p,
1187            typical_p: settings.typical_p,
1188            seed: settings.seed,
1189            repeat_penalty: settings.repeat_penalty,
1190            repeat_last_n: settings.repeat_last_n,
1191            presence_penalty: settings.presence_penalty,
1192            frequency_penalty: settings.frequency_penalty,
1193            mirostat: Some(match settings.mirostat {
1194                crate::models::Mirostat::Off => 0,
1195                crate::models::Mirostat::V1 => 1,
1196                crate::models::Mirostat::Mirostat2 => 2,
1197            }),
1198            mirostat_lr: Some(settings.mirostat_lr),
1199            mirostat_ent: Some(settings.mirostat_ent),
1200            max_tokens: settings.max_tokens,
1201            flash_attn: settings.flash_attn,
1202            kv_cache_offload: settings.kv_cache_offload,
1203            cache_type_k: settings.cache_type_k.map(|k| k.to_string()),
1204            cache_type_v: settings.cache_type_v.map(|k| k.to_string()),
1205            uniform_cache: settings.uniform_cache,
1206            mlock: settings.mlock,
1207            mmap: settings.mmap,
1208            embedding: settings.embedding,
1209            jinja: settings.jinja,
1210            ignore_eos: settings.ignore_eos,
1211            samplers: settings.samplers.to_string(),
1212            expert_count: settings.expert_count as u32,
1213            gpu_layers,
1214            backend: settings.backend.to_string(),
1215            llama_cpp_version: settings.get_active_backend_version_display().to_string(),
1216            spec_type: settings.spec_type.clone(),
1217            draft_tokens: settings.draft_tokens,
1218        }
1219    }
1220}
1221
1222/// Estimate VRAM usage (in MiB) for a model with the given settings.
1223///
1224/// Model file size is the size of the GGUF file in MiB. The model itself
1225/// takes 1x its size in VRAM (loaded as-is). KV cache is the dominant
1226/// variable cost — it scales with context_length, batch_size, and layers.
1227///
1228/// The KV cache formula accounts for:
1229/// - Actual GQA ratio from model metadata (n_kv_head / n_head)
1230/// - FlashAttention: reduces KV cache storage by ~2x
1231/// - Unified KV cache: shares KV across sequences, dividing by parallel count
1232/// - KV cache quantization (q4_0, q5_0, q8_0, etc.)
1233pub fn estimate_vram_mib(
1234    model_mib: u64,
1235    settings: &ModelSettings,
1236    total_layers: u32,
1237    hidden_size_opt: Option<u32>,
1238    n_head_opt: Option<u32>,
1239    n_kv_head_opt: Option<u32>,
1240    gpu_mem_total_mib: u64,
1241) -> u64 {
1242    let model_mib_f = model_mib as f64;
1243
1244    // Compute how much of the model is loaded into VRAM based on GPU layers.
1245    let gpu_layers = match settings.gpu_layers_mode {
1246        GpuLayersMode::Auto => {
1247            // Heuristic: ~60% of layers when Auto (llama.cpp will decide at runtime)
1248            if total_layers > 0 {
1249                (total_layers as f64 * 0.6) as u32
1250            } else {
1251                20
1252            }
1253        }
1254        GpuLayersMode::Specific(n) => {
1255            if total_layers > 0 {
1256                n.min(total_layers)
1257            } else {
1258                n
1259            }
1260        }
1261        GpuLayersMode::All => {
1262            if total_layers > 0 {
1263                total_layers
1264            } else {
1265                32
1266            }
1267        }
1268    };
1269
1270    // Model weights loaded into VRAM: proportional to GPU layers.
1271    let model_vram = if total_layers > 0 && gpu_layers > 0 {
1272        model_mib_f * (gpu_layers as f64 / total_layers as f64).min(1.0)
1273    } else if gpu_layers > 0 {
1274        model_mib_f
1275    } else {
1276        0.0
1277    };
1278
1279    if matches!(settings.gpu_layers_mode, GpuLayersMode::Specific(0)) {
1280        return 0; // CPU only
1281    }
1282
1283    // Heuristic for hidden_size if not provided:
1284    // A 7B model (4-bit) is ~4000 MiB and has hidden=4096.
1285    let hidden_size = match hidden_size_opt {
1286        Some(h) => h as f64,
1287        None => {
1288            let params_est = model_mib_f / 550.0;
1289            (1024.0 * params_est.sqrt().max(1.0) * 1.5).max(512.0)
1290        }
1291    };
1292
1293    // ── KV cache estimation ─────────────────────────────────────
1294
1295    // GQA ratio: real KV heads vs query heads.
1296    // If n_kv_head == n_head, ratio = 1.0 (no reduction).
1297    // If n_kv_head < n_head, ratio < 1.0 (KV cache is smaller).
1298    let gqa_ratio = match (n_head_opt, n_kv_head_opt) {
1299        (Some(n_head), Some(n_kv_head)) if n_head > 0 => n_kv_head as f64 / n_head as f64,
1300        _ => 1.0, // fallback: assume no GQA
1301    };
1302
1303    // FlashAttention reduces KV cache storage by ~2x because it doesn't
1304    // need to keep the full attention matrix in memory.
1305    let flash_attn_factor = if settings.flash_attn { 0.5 } else { 1.0 };
1306
1307    // Unified KV cache shares a single KV buffer across all sequences.
1308    let uniform_cache_factor = if settings.uniform_cache {
1309        1.0 / settings.parallel as f64
1310    } else {
1311        1.0
1312    };
1313
1314    // KV cache in MiB:
1315    // Formula: 2 * n_layer * n_ctx * n_embd_kv * sizeof(type)
1316    // n_embd_kv = hidden_size * gqa_ratio
1317    // The KV cache is allocated for the total number of model layers,
1318    // not just the number layers loaded into the GPU (gpu_layers).
1319    // However only gpu_layers * sizeof(type) contributes to the VRAM cost.
1320    let kv_mib = (2.0
1321        * hidden_size
1322        * settings.context_length as f64
1323        * total_layers as f64
1324        * gqa_ratio
1325        * gpu_layers as f64
1326        / total_layers as f64  // VRAM cost: only GPU-loaded portion of KV cache
1327        * flash_attn_factor
1328        * uniform_cache_factor
1329        * kv_quant_bytes(
1330            settings.cache_type_k.unwrap_or(CacheTypeK::F16),
1331            settings.cache_type_v.unwrap_or(CacheTypeV::F16)
1332        ))
1333        / (1024.0 * 1024.0);
1334
1335    // Activation overhead during inference (proportional to batch * hidden).
1336    // Increased multiplier to 8.0 (from 2.0) to be more pessimistic about scratch buffers.
1337    let activation_mib = (settings.batch_size as f64 * hidden_size * 8.0) / (1024.0 * 1024.0);
1338
1339    // Fixed overhead for driver, fragmentation, and small meta buffers.
1340    // Use 3.8% of max VRAM, falling back to 500MiB if unknown.
1341    let fixed_overhead = if gpu_mem_total_mib > 0 {
1342        gpu_mem_total_mib as f64 * 0.038
1343    } else {
1344        500.0
1345    };
1346
1347    let total_mib = model_vram + kv_mib + activation_mib + fixed_overhead + 550.0;
1348
1349    total_mib.ceil() as u64
1350}
1351
1352/// Return the average KV cache element size in bytes for the given K/V types.
1353///
1354/// KV cache stores K and V separately, potentially at different precisions.
1355/// We average the two to get a single per-element size.
1356fn kv_quant_bytes(k_type: CacheQuantType, v_type: CacheQuantType) -> f64 {
1357    let get_bytes = |t: CacheQuantType| match t {
1358        CacheQuantType::F32 => 4.0,
1359        CacheQuantType::F16 | CacheQuantType::BF16 => 2.0,
1360        CacheQuantType::Q8_0 => 1.0,
1361        CacheQuantType::Q5_0 | CacheQuantType::Q5_1 => 0.625, // 5 bits
1362        CacheQuantType::Q4_0 | CacheQuantType::Q4_1 | CacheQuantType::Iq4Nl => 0.5, // 4 bits
1363    };
1364    (get_bytes(k_type) + get_bytes(v_type)) / 2.0
1365}
1366
1367impl ModelSettings {
1368    /// Check if this settings differs from `other` in any field.
1369    /// Uses derived PartialEq which compares all fields — compiler-enforced.
1370    pub fn is_dirty(&self, other: &Self) -> bool {
1371        self != other
1372    }
1373}
1374
1375// Benchmark Tuning types
1376#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
1377pub struct BenchTuneConfig {
1378    pub model_path: PathBuf,
1379    pub num_iterations: u32,
1380    pub prompt: String,
1381    pub params_to_test: Vec<BenchTuneParam>,
1382    pub test_duration: Duration,
1383    pub bench_mode: BenchTuneMode,
1384    pub n_predict: u32,
1385    pub chat_template_kwargs: Option<String>,
1386    pub test_timeout: Duration,
1387}
1388
1389#[derive(Debug, Clone, Serialize, Deserialize)]
1390pub struct BenchTuneParam {
1391    pub name: String,
1392    pub min: f64,
1393    pub max: f64,
1394    pub step: f64,
1395    pub enabled: bool,
1396    pub variants: Vec<String>,
1397}
1398
1399impl PartialEq for BenchTuneParam {
1400    fn eq(&self, other: &Self) -> bool {
1401        self.name == other.name
1402            && self.min.to_bits() == other.min.to_bits()
1403            && self.max.to_bits() == other.max.to_bits()
1404            && self.step.to_bits() == other.step.to_bits()
1405            && self.enabled == other.enabled
1406            && self.variants == other.variants
1407    }
1408}
1409impl Eq for BenchTuneParam {}
1410
1411#[derive(Debug, Clone, Serialize, Deserialize)]
1412pub struct BenchTuneParamValue {
1413    pub temperature: Option<f64>,
1414    pub top_p: Option<f64>,
1415    pub top_k: Option<i64>,
1416    pub repeat_penalty: Option<f64>,
1417    pub context_length: Option<u32>,
1418    pub batch_size: Option<u32>,
1419    pub flash_attn: Option<bool>,
1420    pub threads: Option<u32>,
1421    pub expert_count: Option<i32>,
1422    pub spec_type: Option<String>,
1423    pub draft_tokens: Option<u32>,
1424}
1425
1426impl PartialEq for BenchTuneParamValue {
1427    fn eq(&self, other: &Self) -> bool {
1428        self.temperature.map(|v| v.to_bits()) == other.temperature.map(|v| v.to_bits())
1429            && self.top_p.map(|v| v.to_bits()) == other.top_p.map(|v| v.to_bits())
1430            && self.top_k == other.top_k
1431            && self.repeat_penalty.map(|v| v.to_bits()) == other.repeat_penalty.map(|v| v.to_bits())
1432            && self.context_length == other.context_length
1433            && self.batch_size == other.batch_size
1434            && self.flash_attn == other.flash_attn
1435            && self.threads == other.threads
1436            && self.expert_count == other.expert_count
1437            && self.spec_type == other.spec_type
1438            && self.draft_tokens == other.draft_tokens
1439    }
1440}
1441impl Eq for BenchTuneParamValue {}
1442
1443#[derive(Debug, Clone, Serialize, Deserialize)]
1444pub struct BenchTuneResult {
1445    pub params: BenchTuneParamValue,
1446    pub metrics: BenchTuneMetrics,
1447    pub outputs: Vec<String>,
1448    pub per_iteration_metrics: Vec<BenchTuneMetrics>,
1449    pub base_settings: Option<ModelSettings>,
1450    pub server_command: Option<String>,
1451}
1452
1453#[derive(Debug, Clone, Serialize, Deserialize)]
1454pub struct BenchTuneMetrics {
1455    pub prompt_tps: f64,
1456    pub generation_tps: f64,
1457    pub combined_tps: f64,
1458    pub latency_per_token: f64,
1459    pub first_token_time: f64,
1460}
1461
1462#[derive(Debug, Clone, Serialize, Deserialize)]
1463pub enum BenchTuneStatus {
1464    Running {
1465        current: usize,
1466        total: usize,
1467        progress: f32,
1468        current_params: BenchTuneParamValue,
1469    },
1470    Completed {
1471        total_tests: usize,
1472        successful_tests: usize,
1473        elapsed: Duration,
1474    },
1475    PartiallyCompleted {
1476        total_tests: usize,
1477        successful_tests: usize,
1478        failed_tests: usize,
1479        elapsed: Duration,
1480    },
1481    Cancelled {
1482        total_tests: usize,
1483        successful_tests: usize,
1484        failed_tests: usize,
1485        elapsed: Duration,
1486    },
1487    Error {
1488        error: String,
1489    },
1490}
1491
1492#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
1493#[derive(Default)]
1494pub enum BenchTuneMode {
1495    /// Runtime-only mode: sends all params in /completion request body, no server restarts
1496    RuntimeOnly,
1497    /// Full mode: spawns a new server for each parameter combination (tests server-level params)
1498    #[default]
1499    Full,
1500}
1501
1502
1503/// Progress status for benchmark tuning
1504#[derive(Debug, Clone)]
1505pub enum BenchTuneProgress {
1506    /// Tuning is running.
1507    Running {
1508        current: usize,
1509        total: usize,
1510        progress: f32,
1511        current_params: BenchTuneParamValue,
1512    },
1513    /// Tuning is complete.
1514    Completed {
1515        total_tests: usize,
1516        successful_tests: usize,
1517        elapsed: Duration,
1518    },
1519    /// Tuning completed with some failures.
1520    PartiallyCompleted {
1521        total_tests: usize,
1522        successful_tests: usize,
1523        failed_tests: usize,
1524        elapsed: Duration,
1525    },
1526    /// Tuning was cancelled by the user.
1527    Cancelled {
1528        total_tests: usize,
1529        successful_tests: usize,
1530        failed_tests: usize,
1531        elapsed: Duration,
1532    },
1533    /// Tuning failed.
1534    Error { error: String },
1535}
1536
1537impl BenchTuneProgress {
1538    pub fn from_status(status: &BenchTuneStatus) -> Option<Self> {
1539        match status {
1540            BenchTuneStatus::Running {
1541                current,
1542                total,
1543                progress,
1544                current_params,
1545            } => Some(BenchTuneProgress::Running {
1546                current: *current,
1547                total: *total,
1548                progress: *progress,
1549                current_params: current_params.clone(),
1550            }),
1551            BenchTuneStatus::Completed {
1552                total_tests,
1553                successful_tests,
1554                elapsed,
1555            } => Some(BenchTuneProgress::Completed {
1556                total_tests: *total_tests,
1557                successful_tests: *successful_tests,
1558                elapsed: *elapsed,
1559            }),
1560            BenchTuneStatus::PartiallyCompleted {
1561                total_tests,
1562                successful_tests,
1563                failed_tests,
1564                elapsed,
1565            } => Some(BenchTuneProgress::PartiallyCompleted {
1566                total_tests: *total_tests,
1567                successful_tests: *successful_tests,
1568                failed_tests: *failed_tests,
1569                elapsed: *elapsed,
1570            }),
1571            BenchTuneStatus::Cancelled {
1572                total_tests,
1573                successful_tests,
1574                failed_tests,
1575                elapsed,
1576            } => Some(BenchTuneProgress::Cancelled {
1577                total_tests: *total_tests,
1578                successful_tests: *successful_tests,
1579                failed_tests: *failed_tests,
1580                elapsed: *elapsed,
1581            }),
1582            BenchTuneStatus::Error { error } => Some(BenchTuneProgress::Error {
1583                error: error.clone(),
1584            }),
1585        }
1586    }
1587}
1588
1589impl BenchTuneConfig {
1590    pub fn new(model_path: PathBuf, num_iterations: u32, prompt: String) -> Self {
1591        Self {
1592            model_path,
1593            num_iterations,
1594            prompt,
1595            params_to_test: vec![
1596                BenchTuneParam {
1597                    name: "temperature".to_string(),
1598                    min: 0.4,
1599                    max: 1.0,
1600                    step: 0.1,
1601                    enabled: false,
1602                    variants: vec![],
1603                },
1604                BenchTuneParam {
1605                    name: "top_p".to_string(),
1606                    min: 0.8,
1607                    max: 1.0,
1608                    step: 0.1,
1609                    enabled: false,
1610                    variants: vec![],
1611                },
1612                BenchTuneParam {
1613                    name: "top_k".to_string(),
1614                    min: 10.0,
1615                    max: 40.0,
1616                    step: 5.0,
1617                    enabled: false,
1618                    variants: vec![],
1619                },
1620                BenchTuneParam {
1621                    name: "repeat_penalty".to_string(),
1622                    min: 1.0,
1623                    max: 1.5,
1624                    step: 0.1,
1625                    enabled: false,
1626                    variants: vec![],
1627                },
1628                BenchTuneParam {
1629                    name: "flash_attn".to_string(),
1630                    min: 0.0,
1631                    max: 1.0,
1632                    step: 1.0,
1633                    enabled: false,
1634                    variants: vec![],
1635                },
1636                BenchTuneParam {
1637                    name: "threads".to_string(),
1638                    min: 4.0,
1639                    max: 16.0,
1640                    step: 4.0,
1641                    enabled: false,
1642                    variants: vec![],
1643                },
1644                BenchTuneParam {
1645                    name: "batch_size".to_string(),
1646                    min: 512.0,
1647                    max: 2048.0,
1648                    step: 512.0,
1649                    enabled: false,
1650                    variants: vec![],
1651                },
1652                BenchTuneParam {
1653                    name: "expert_count".to_string(),
1654                    min: -1.0,
1655                    max: 4.0,
1656                    step: 1.0,
1657                    enabled: false,
1658                    variants: vec![],
1659                },
1660                BenchTuneParam {
1661                    name: "spec_type".to_string(),
1662                    min: 0.0,
1663                    max: 8.0,
1664                    step: 1.0,
1665                    enabled: false,
1666                    variants: vec![
1667                        "Off".to_string(),
1668                        "draft-mtp".to_string(),
1669                        "draft-simple".to_string(),
1670                        "draft-eagle3".to_string(),
1671                        "ngram-simple".to_string(),
1672                        "ngram-map-k".to_string(),
1673                        "ngram-map-k4v".to_string(),
1674                        "ngram-mod".to_string(),
1675                        "ngram-cache".to_string(),
1676                    ],
1677                },
1678     BenchTuneParam {
1679                    name: "draft_tokens".to_string(),
1680                    min: 0.0,
1681                    max: 8.0,
1682                    step: 1.0,
1683                    enabled: false,
1684                    variants: vec![],
1685                },
1686            ],
1687            test_duration: Duration::from_secs(30),
1688            bench_mode: BenchTuneMode::default(),
1689            n_predict: 512,
1690            chat_template_kwargs: Some(r#"{"enable_thinking": false}"#.to_string()),
1691            test_timeout: Duration::from_secs(60),
1692        }
1693    }
1694
1695    /// Generate all parameter combinations based on the config
1696    pub fn generate_combinations(&self) -> Vec<BenchTuneParamValue> {
1697        let mut temp_values = vec![None];
1698        let mut top_p_values = vec![None];
1699        let mut top_k_values = vec![None];
1700        let mut repeat_penalty_values = vec![None];
1701        let mut flash_attn_values = vec![None];
1702        let mut threads_values = vec![None];
1703        let mut batch_size_values = vec![None];
1704        let mut expert_count_values = vec![None];
1705        let mut selected_spec_type = None;
1706        if let Some(p) = self.params_to_test.iter().find(|p| p.name == "spec_type") {
1707            let base_idx = (p.min as usize).min(p.variants.len().saturating_sub(1));
1708            let val = &p.variants[base_idx];
1709            if val != "Off" {
1710                selected_spec_type = Some(val.clone());
1711            }
1712        }
1713        let mut spec_type_values = vec![selected_spec_type];
1714        let mut draft_tokens_values = vec![None];
1715
1716        let spec_type_options = vec![
1717            "Off".to_string(),
1718            "draft-mtp".to_string(),
1719            "draft-simple".to_string(),
1720            "draft-eagle3".to_string(),
1721            "ngram-simple".to_string(),
1722            "ngram-map-k".to_string(),
1723            "ngram-map-k4v".to_string(),
1724            "ngram-mod".to_string(),
1725            "ngram-cache".to_string(),
1726        ];
1727
1728        for p in &self.params_to_test {
1729            if !p.enabled {
1730                continue;
1731            }
1732
1733            let vals: Vec<f64> = {
1734                let step_count = ((p.max - p.min) / p.step).ceil() as usize;
1735                (0..=step_count)
1736                    .map(|i| (p.min + (i as f64 * p.step)).min(p.max))
1737                    .collect()
1738            };
1739
1740            match p.name.as_str() {
1741                "temperature" => temp_values = vals.into_iter().map(Some).collect(),
1742                "top_p" => top_p_values = vals.into_iter().map(Some).collect(),
1743                "top_k" => top_k_values = vals.into_iter().map(|v| Some(v as i64)).collect(),
1744                "repeat_penalty" => repeat_penalty_values = vals.into_iter().map(Some).collect(),
1745                "flash_attn" => {
1746                    flash_attn_values = vals.into_iter().map(|v| Some(v >= 0.5)).collect()
1747                }
1748                "threads" => threads_values = vals.into_iter().map(|v| Some(v as u32)).collect(),
1749                "batch_size" => {
1750                    batch_size_values = vals.into_iter().map(|v| Some(v as u32)).collect()
1751                }
1752                "expert_count" => {
1753                    expert_count_values = vals.into_iter().map(|v| Some(v as i32)).collect()
1754                }
1755                "spec_type" => {
1756                    if !p.variants.is_empty() {
1757                        spec_type_values = p.variants.clone().into_iter().map(Some).collect();
1758                    } else {
1759                        let step_count = ((p.max - p.min) / p.step).ceil() as usize;
1760                        spec_type_values = (0..=step_count)
1761                            .map(|i| {
1762                                let idx = i.min(spec_type_options.len() - 1);
1763                                Some(spec_type_options[idx].clone())
1764                            })
1765                            .collect()
1766                    }
1767                }
1768                "draft_tokens" => {
1769                    draft_tokens_values = vals.into_iter().map(|v| Some(v as u32)).collect()
1770                }
1771                _ => {}
1772            }
1773        }
1774
1775        let mut combinations = Vec::new();
1776        for &temp in &temp_values {
1777            for &top_p in &top_p_values {
1778                for &top_k in &top_k_values {
1779                    for &rp in &repeat_penalty_values {
1780                        for &fa in &flash_attn_values {
1781                            for &th in &threads_values {
1782                                for &bs in &batch_size_values {
1783                                    for &ec in &expert_count_values {
1784                                        for st in &spec_type_values {
1785                                            for &dt in &draft_tokens_values {
1786                                                combinations.push(BenchTuneParamValue {
1787                                                    temperature: temp,
1788                                                    top_p,
1789                                                    top_k,
1790                                                    repeat_penalty: rp,
1791                                                    context_length: None,
1792                                                    batch_size: bs,
1793                                                    flash_attn: fa,
1794                                                    threads: th,
1795                                                    expert_count: ec,
1796                                                    spec_type: st.clone(),
1797                                                    draft_tokens: dt,
1798                                                });
1799                                            }
1800                                        }
1801                                    }
1802                                }
1803                            }
1804                        }
1805                    }
1806                }
1807            }
1808        }
1809
1810        combinations
1811    }
1812
1813    /// Get total number of tests to run
1814    pub fn get_total_tests_count(&self) -> usize {
1815        self.generate_combinations().len()
1816    }
1817
1818    /// Get number of parameter combinations (fast, without generating them)
1819    pub fn get_num_combinations(&self) -> usize {
1820        let mut count: u64 = 1;
1821        for p in &self.params_to_test {
1822            if !p.enabled {
1823                continue;
1824            }
1825            let vals = if p.name == "flash_attn" {
1826                2 // On/Off
1827            } else if !p.variants.is_empty() {
1828                p.variants.len()
1829            } else if p.name == "spec_type" {
1830                let step_count = ((p.max - p.min) / p.step).ceil() as usize;
1831                (step_count + 1).min(9)
1832            } else {
1833                let step_count = ((p.max - p.min) / p.step).ceil() as usize;
1834                step_count + 1
1835            };
1836            count *= vals as u64;
1837        }
1838        count as usize
1839    }
1840}
1841
1842// ── Parameter struct field count tests ──────────────────────────
1843
#[cfg(test)]
mod field_count_tests {
    use super::*;

    /// Verify ModelSettings has the expected number of fields.
    ///
    /// `count_model_settings_fields` destructures `ModelSettings` exhaustively, so a
    /// field that is added or removed first breaks compilation of this module. Once the
    /// field list there is updated, the count changes and this assertion fails as a
    /// reminder to update the checklist in src/models.rs:665 and all locations listed there.
    #[test]
    fn test_model_settings_field_count() {
        let s = ModelSettings::default();
        let field_count = count_model_settings_fields(&s);
        assert_eq!(
            field_count, 75,
            "ModelSettings has {} fields (expected 75). \
             Update the checklist at src/models.rs:665 and all locations listed there.",
            field_count
        );
    }

    /// Count the fields of `ModelSettings` as listed below.
    ///
    /// The list is used in a struct pattern without `..`, so the compiler rejects it
    /// unless it names every field exactly once. Removing a field from the struct, or
    /// adding one without listing it here, is a compile error. The returned count is the
    /// length of that same list, so it always equals the real field count.
    // The field list is inherently long; splitting it would defeat the exhaustiveness check.
    #[allow(clippy::too_many_lines)]
    fn count_model_settings_fields(s: &ModelSettings) -> usize {
        macro_rules! exhaustive_field_count {
            ($settings:expr; $($field:ident),+ $(,)?) => {{
                // No `..` here: every field must be named or this fails to compile.
                let ModelSettings { $($field: _),+ } = $settings;
                [$(stringify!($field)),+].len()
            }};
        }

        exhaustive_field_count!(s;
            context_length,
            threads,
            threads_batch,
            batch_size,
            ubatch_size,
            parallel,
            max_concurrent_predictions,
            uniform_cache,
            kv_cache_offload,
            cache_type_k,
            cache_type_v,
            keep,
            swa_full,
            mlock,
            mmap,
            numa,
            system_prompt,
            system_prompt_preset_name,
            gpu_layers_mode,
            split_mode,
            tensor_split,
            main_gpu,
            fit,
            lora,
            lora_scaled,
            rpc,
            embedding,
            flash_attn,
            expert_count,
            jinja,
            chat_template,
            chat_template_kwargs,
            seed,
            temperature,
            top_k,
            top_p,
            min_p,
            typical_p,
            mirostat,
            mirostat_lr,
            mirostat_ent,
            ignore_eos,
            samplers,
            repeat_penalty,
            repeat_last_n,
            presence_penalty,
            frequency_penalty,
            dry_multiplier,
            dry_base,
            dry_allowed_length,
            dry_penalty_last_n,
            rope_scaling,
            rope_scale,
            rope_freq_base,
            rope_freq_scale,
            rope_yarn_enabled,
            host,
            port,
            timeout,
            cache_prompt,
            cache_reuse,
            webui,
            max_tokens,
            cache_type,
            backend,
            llama_cpp_version_cpu,
            llama_cpp_version_vulkan,
            llama_cpp_version_rocm,
            llama_cpp_version_rocm_lemonade,
            llama_cpp_version_cuda,
            api_endpoint_enabled,
            api_endpoint_port,
            spec_type,
            draft_tokens,
            tags,
        )
    }

    /// Verify that `is_dirty()` agrees with the derived `PartialEq`.
    ///
    /// Identical settings (two defaults, or a clone) must not be dirty, and a copy
    /// with a changed field must be dirty.
    #[test]
    fn test_is_dirty_uses_derived_eq() {
        let s1 = ModelSettings::default();
        let s2 = ModelSettings::default();
        let s3 = s1.clone();

        // Identical settings should not be dirty.
        assert!(!s1.is_dirty(&s2));
        assert!(!s1.is_dirty(&s3));

        // Changing any field must be reported as dirty.
        let mut s4 = s1.clone();
        s4.temperature += 1.0;
        assert!(s1.is_dirty(&s4));

        // Derived PartialEq must match is_dirty in both directions.
        assert_eq!(s1 != s2, s1.is_dirty(&s2));
        assert_eq!(s1 != s3, s1.is_dirty(&s3));
        assert_eq!(s1 != s4, s1.is_dirty(&s4));
    }

    /// Spot-check the `From<DefaultParams> for ModelSettings` conversion.
    ///
    /// Only a few fields are compared, so this does not prove that every field is
    /// mapped. It pins the converted `context_length` and `temperature` and checks
    /// that `threads` and `backend` are carried over from the source `DefaultParams`.
    #[test]
    fn test_from_default_params_completeness() {
        let dp = crate::config::DefaultParams::default();
        let ms: ModelSettings = dp.clone().into();

        assert_eq!(ms.context_length, 131072);
        assert_eq!(ms.threads, dp.threads);
        assert_eq!(ms.temperature, 0.8);
        // backend is hardware-dependent in DefaultParams::default(),
        // so compare against the source value rather than a constant.
        assert_eq!(ms.backend, dp.backend);
    }
}