Skip to main content

llm_manager/
models.rs

1use std::path::PathBuf;
2use std::time::Duration;
3
4use serde::{Deserialize, Serialize};
5
/// The state of a model in the manager.
///
/// Only `Loaded` and `Failed` carry data; the other variants are plain markers.
#[derive(Debug, Clone, PartialEq)]
pub enum ModelState {
    /// Marker variant with no payload.
    /// NOTE(review): presumably "known on disk but not running" — confirm against the manager.
    Available,
    /// Marker variant with no payload; by its name, a load is in progress.
    Loading,
    /// Marker variant with no payload; by its name, a benchmark run is in progress.
    Benchmarking,
    /// The model is loaded. Carries a `port` and a `pid`
    /// (presumably those of the spawned server process — confirm with the code that builds it).
    Loaded { port: u16, pid: u32 },
    /// Loading failed; `error` holds a description of the failure.
    Failed { error: String },
}
15
/// Sort order for search results.
///
/// The declaration order below is also the order in which `SearchSort::next`
/// steps through the variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SearchSort {
    /// Labelled "Relevance".
    Relevance,
    /// Labelled "Downloads".
    Downloads,
    /// Labelled "Likes".
    Likes,
    /// Labelled "Trending".
    Trending,
    /// Labelled "Created" (note: shorter than the variant name).
    CreatedAt,
}
25
26impl SearchSort {
27    pub fn next(self) -> Self {
28        match self {
29            SearchSort::Relevance => SearchSort::Downloads,
30            SearchSort::Downloads => SearchSort::Likes,
31            SearchSort::Likes => SearchSort::Trending,
32            SearchSort::Trending => SearchSort::CreatedAt,
33            SearchSort::CreatedAt => SearchSort::Relevance,
34        }
35    }
36
37    pub fn label(self) -> &'static str {
38        match self {
39            SearchSort::Relevance => "Relevance",
40            SearchSort::Downloads => "Downloads",
41            SearchSort::Likes => "Likes",
42            SearchSort::Trending => "Trending",
43            SearchSort::CreatedAt => "Created",
44        }
45    }
46}
47
/// A model found via HuggingFace search.
///
/// Serializable with serde. `downloaded` is the only field with a serde
/// default: when it is absent from the input it deserializes as `false`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct SearchResult {
    /// Identifier of the model (presumably the HuggingFace repo id — confirm).
    pub model_id: String,
    /// Name of the model.
    pub model_name: String,
    /// Tags attached to the model.
    pub tags: Vec<String>,
    /// Download count.
    pub downloads: u64,
    /// Like count.
    pub likes: u64,
    /// Pipeline tag, when one is known.
    pub pipeline_tag: Option<String>,
    /// Size, when known. NOTE(review): unit is not visible here (presumably bytes) — confirm.
    pub size: Option<u64>,
    /// Parameter count as a display string, when known.
    pub parameters: Option<String>,
    /// Capability labels for the model.
    pub capabilities: Vec<String>,
    /// Context length, when known.
    pub context_length: Option<u32>,
    /// README text, when available.
    pub readme: Option<String>,
    /// Quantization type extracted from GGUF metadata (e.g. "Q4_K_M", "Q8_0").
    pub quantization: Option<String>,
    /// License extracted from tags (e.g. "apache-2.0", "llama3.1").
    pub license: Option<String>,
    /// HuggingFace trending score.
    pub trending_score: i64,
    /// Creation timestamp string.
    pub created_at: Option<String>,
    /// Whether a matching GGUF file is already downloaded locally.
    /// Defaults to `false` when missing from deserialized input.
    #[serde(default)]
    pub downloaded: bool,
}
74
/// Download progress information.
///
/// NOTE(review): cancellation/pause state is represented several times here
/// (`status`, `cancelled`, `cancel_token`, `download_state`, `download_state_arc`);
/// which one is authoritative is not visible from this file — confirm with the
/// download loop before relying on any single field.
#[derive(Debug, Clone)]
pub struct DownloadState {
    /// Identifier of the model being downloaded.
    pub model_id: String,
    /// Name of the file being downloaded.
    pub filename: String,
    /// Expected total size in bytes.
    pub total_bytes: u64,
    /// Bytes received so far; starts at 0 in `DownloadState::new`.
    pub downloaded_bytes: u64,
    /// Coarse status of the download; starts as `DownloadStatus::Downloading`.
    pub status: DownloadStatus,
    /// Cancellation flag; starts as `false`.
    pub cancelled: bool,
    /// Optional shared cancellation flag; `None` until one is attached.
    pub cancel_token: Option<std::sync::Arc<std::sync::atomic::AtomicBool>>,
    /// Download control: 1=downloading, 2=paused, 3=cancelled
    pub download_state: u8,
    /// Shared atomic state for pausing/resuming the download loop
    pub download_state_arc: Option<std::sync::Arc<std::sync::atomic::AtomicU8>>,
    /// Instant captured when this state was created.
    pub start_time: std::time::Instant,
    /// Transfer rate in bytes per second; starts at 0.0.
    pub bytes_per_second: f64,
    /// Filesystem path where the download is being saved.
    pub dest: Option<std::path::PathBuf>,
}
94
95impl DownloadState {
96    pub fn new(model_id: String, filename: String, total_bytes: u64) -> Self {
97        Self {
98            model_id,
99            filename,
100            total_bytes,
101            downloaded_bytes: 0,
102            status: DownloadStatus::Downloading,
103            cancelled: false,
104            cancel_token: None,
105            download_state: 1,
106            download_state_arc: None,
107            start_time: std::time::Instant::now(),
108            bytes_per_second: 0.0,
109            dest: None,
110        }
111    }
112}
113
114impl ModelSettings {
115    /// Get the version string for the currently active backend.
116    pub fn get_active_backend_version(&self) -> Option<&String> {
117        match self.backend {
118            Backend::Cpu => self.llama_cpp_version_cpu.as_ref(),
119            Backend::Vulkan => self.llama_cpp_version_vulkan.as_ref(),
120            Backend::Rocm => self.llama_cpp_version_rocm.as_ref(),
121            Backend::RocmLemonade => self.llama_cpp_version_rocm_lemonade.as_ref(),
122            Backend::Cuda => self.llama_cpp_version_cuda.as_ref(),
123            _ => None,
124        }
125    }
126
127    /// Get the display version string for the currently active backend (defaults to "latest").
128    pub fn get_active_backend_version_display(&self) -> &str {
129        self.get_active_backend_version()
130            .map(|s| s.as_str())
131            .unwrap_or("latest")
132    }
133
134    /// Set the version string for the currently active backend.
135    pub fn set_active_backend_version(&mut self, tag: Option<String>) {
136        match self.backend {
137            Backend::Cpu => self.llama_cpp_version_cpu = tag,
138            Backend::Vulkan => self.llama_cpp_version_vulkan = tag,
139            Backend::Rocm => self.llama_cpp_version_rocm = tag,
140            Backend::RocmLemonade => self.llama_cpp_version_rocm_lemonade = tag,
141            Backend::Cuda => self.llama_cpp_version_cuda = tag,
142            _ => {}
143        }
144    }
145}
146
/// Strip the `.gguf` extension from a model name.
///
/// The extension is matched ASCII case-insensitively, so `.gguf`, `.GGUF`
/// and mixed spellings such as `.Gguf` are all removed. Names without the
/// extension (including names shorter than the extension) are returned
/// unchanged. Only one trailing extension is stripped.
pub fn strip_gguf(name: &str) -> &str {
    const EXT: &[u8] = b".gguf";

    let stem_len = match name.len().checked_sub(EXT.len()) {
        Some(n) => n,
        // Shorter than the extension itself: nothing to strip.
        None => return name,
    };

    // Compare raw bytes so a multi-byte character straddling `stem_len`
    // cannot cause a slicing panic. When the tail matches, the byte at
    // `stem_len` is the ASCII '.', which is always a char boundary, so the
    // string slice below is safe.
    if name.as_bytes()[stem_len..].eq_ignore_ascii_case(EXT) {
        &name[..stem_len]
    } else {
        name
    }
}
153
/// Normalizes a host string for use in URLs and CLI arguments.
///
/// * Blank input (empty or whitespace only) becomes `"127.0.0.1"`.
/// * Only the first whitespace-separated token is kept, which drops display
///   suffixes such as the `(127.0.0.1)` in `"localhost (127.0.0.1)"`.
/// * A token containing `:` that is not already bracketed is wrapped in
///   `[...]`, as needed for IPv6 literals.
pub fn clean_host(host: &str) -> String {
    // The first token of a blank string does not exist, which covers both the
    // "empty" and the "whitespace only" case.
    let token = match host.split_whitespace().next() {
        Some(first) => first,
        None => return String::from("127.0.0.1"),
    };

    let already_bracketed = token.starts_with('[');
    if !already_bracketed && token.contains(':') {
        ["[", token, "]"].concat()
    } else {
        token.to_owned()
    }
}
170
/// Returns the display form of a host string.
///
/// The empty string and `"127.0.0.1"` are both shown as
/// `"localhost (127.0.0.1)"`; any other value is returned untouched.
pub fn format_host(host: &str) -> &str {
    let is_loopback_default = host.is_empty() || host == "127.0.0.1";
    if is_loopback_default {
        "localhost (127.0.0.1)"
    } else {
        host
    }
}
178
/// Converts the configured default parameters into a `ModelSettings` value.
///
/// Every field is moved across under the same name, with one exception:
/// `gpu_layers_mode` is derived from two source fields (see the comment on
/// that field below). `dp.gpu_layers` itself has no counterpart in
/// `ModelSettings` and is only consulted for that decision.
impl From<crate::config::DefaultParams> for ModelSettings {
    fn from(dp: crate::config::DefaultParams) -> Self {
        Self {
            // Loading
            context_length: dp.context_length,
            threads: dp.threads,
            threads_batch: dp.threads_batch,
            batch_size: dp.batch_size,
            ubatch_size: dp.ubatch_size,
            parallel: dp.parallel,
            max_concurrent_predictions: dp.max_concurrent_predictions,
            uniform_cache: dp.uniform_cache,
            kv_cache_offload: dp.kv_cache_offload,
            cache_type_k: dp.cache_type_k,
            cache_type_v: dp.cache_type_v,
            keep: dp.keep,
            swa_full: dp.swa_full,
            mlock: dp.mlock,
            mmap: dp.mmap,
            numa: dp.numa,
            system_prompt: dp.system_prompt,
            system_prompt_preset_name: dp.system_prompt_preset_name,

            // GPU
            // A negative `gpu_layers` value forces `GpuLayersMode::All`,
            // overriding whatever `dp.gpu_layers_mode` says; otherwise the
            // configured mode is used as-is.
            // NOTE(review): presumably a legacy "-1 = all layers" sentinel — confirm in config.rs.
            gpu_layers_mode: match dp.gpu_layers {
                n if n < 0 => GpuLayersMode::All,
                _ => dp.gpu_layers_mode,
            },
            split_mode: dp.split_mode,
            tensor_split: dp.tensor_split,
            main_gpu: dp.main_gpu,
            fit: dp.fit,
            lora: dp.lora,
            lora_scaled: dp.lora_scaled,
            rpc: dp.rpc,
            embedding: dp.embedding,
            flash_attn: dp.flash_attn,
            expert_count: dp.expert_count,
            jinja: dp.jinja,
            chat_template: dp.chat_template,
            chat_template_kwargs: dp.chat_template_kwargs,
            // Sampling
            seed: dp.seed,
            temperature: dp.temperature,
            top_k: dp.top_k,
            top_p: dp.top_p,
            min_p: dp.min_p,
            typical_p: dp.typical_p,
            mirostat: dp.mirostat,
            mirostat_lr: dp.mirostat_lr,
            mirostat_ent: dp.mirostat_ent,
            ignore_eos: dp.ignore_eos,
            samplers: dp.samplers,
            // Repetition control
            repeat_penalty: dp.repeat_penalty,
            repeat_last_n: dp.repeat_last_n,
            presence_penalty: dp.presence_penalty,
            frequency_penalty: dp.frequency_penalty,
            dry_multiplier: dp.dry_multiplier,
            dry_base: dp.dry_base,
            dry_allowed_length: dp.dry_allowed_length,
            dry_penalty_last_n: dp.dry_penalty_last_n,
            // RoPE
            rope_scaling: dp.rope_scaling,
            rope_scale: dp.rope_scale,
            rope_freq_base: dp.rope_freq_base,
            rope_freq_scale: dp.rope_freq_scale,
            rope_yarn_enabled: dp.rope_yarn_enabled,
            // Server
            host: dp.host,
            port: dp.port,
            timeout: dp.timeout,
            cache_prompt: dp.cache_prompt,
            cache_reuse: dp.cache_reuse,
            webui: dp.webui,
            // Other
            max_tokens: dp.max_tokens,
            cache_type: dp.cache_type,
            backend: dp.backend,
            llama_cpp_version_cpu: dp.llama_cpp_version_cpu,
            llama_cpp_version_vulkan: dp.llama_cpp_version_vulkan,
            llama_cpp_version_rocm: dp.llama_cpp_version_rocm,
            llama_cpp_version_rocm_lemonade: dp.llama_cpp_version_rocm_lemonade,
            llama_cpp_version_cuda: dp.llama_cpp_version_cuda,
            api_endpoint_enabled: dp.api_endpoint_enabled,
            api_endpoint_port: dp.api_endpoint_port,

            spec_type: dp.spec_type,
            draft_tokens: dp.draft_tokens,
            tags: dp.tags,
        }
    }
}
265
/// How to handle GPU layer offloading.
///
/// `Auto` is the `Default` value.
#[derive(
    Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize, Hash, Default,
)]
pub enum GpuLayersMode {
    /// Default mode. By its name, the layer count is chosen automatically.
    #[default]
    Auto,
    /// An explicit value (presumably the number of layers to offload — confirm with the CLI builder).
    Specific(u32),
    /// By its name, offload all layers. Also what `From<DefaultParams> for ModelSettings`
    /// selects when the configured `gpu_layers` is negative.
    All,
}
276
/// Coarse status of a download, stored in `DownloadState::status`.
///
/// A new `DownloadState` starts as `Downloading`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DownloadStatus {
    /// Initial status set by `DownloadState::new`.
    Downloading,
    /// By its name, a pause has been requested but not yet taken effect — confirm with the download loop.
    Pausing,
    /// The download is paused.
    Paused,
    /// The download finished.
    Complete,
    /// The download failed; the payload is the error text.
    Error(String),
    /// The download was cancelled.
    Cancelled,
}
286
287// ── Cache type enums ──────────────────────────────────────────
288
/// Main KV cache data type.
///
/// `F16` is the `Default`. The serde names given by the `rename` attributes
/// are the same strings the `Display` impl prints.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Hash, Default)]
pub enum CacheType {
    /// Serialized as "f16" (default).
    #[serde(rename = "f16")]
    #[default]
    F16,
    /// Serialized as "bf16".
    #[serde(rename = "bf16")]
    BF16,
    /// Serialized as "fq8_0".
    #[serde(rename = "fq8_0")]
    Fq8_0,
    /// Serialized as "fq4_1".
    #[serde(rename = "fq4_1")]
    Fq4_1,
}
302
303impl std::fmt::Display for CacheType {
304    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
305        match self {
306            CacheType::F16 => write!(f, "f16"),
307            CacheType::BF16 => write!(f, "bf16"),
308            CacheType::Fq8_0 => write!(f, "fq8_0"),
309            CacheType::Fq4_1 => write!(f, "fq4_1"),
310        }
311    }
312}
313
314/// KV cache quantization type.
315#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default, Hash)]
316pub enum CacheQuantType {
317    #[serde(rename = "f32")]
318    F32,
319    #[serde(rename = "f16")]
320    #[default]
321    F16,
322    #[serde(rename = "bf16")]
323    BF16,
324    #[serde(rename = "q8_0")]
325    Q8_0,
326    #[serde(rename = "q4_0")]
327    Q4_0,
328    #[serde(rename = "q4_1")]
329    Q4_1,
330    #[serde(rename = "iq4_nl")]
331    Iq4Nl,
332    #[serde(rename = "q5_0")]
333    Q5_0,
334    #[serde(rename = "q5_1")]
335    Q5_1,
336}
337
338pub type CacheTypeK = CacheQuantType;
339pub type CacheTypeV = CacheQuantType;
340
341impl CacheQuantType {
342    pub fn from_u8(n: u8) -> Self {
343        match n {
344            0 => Self::F32,
345            1 => Self::F16,
346            2 => Self::BF16,
347            3 => Self::Q8_0,
348            4 => Self::Q5_1,
349            5 => Self::Q5_0,
350            6 => Self::Q4_1,
351            7 => Self::Q4_0,
352            8 => Self::Iq4Nl,
353            _ => Self::F16,
354        }
355    }
356    pub fn next(&self) -> Self {
357        match self {
358            Self::F32 => Self::F16,
359            Self::F16 => Self::BF16,
360            Self::BF16 => Self::Q8_0,
361            Self::Q8_0 => Self::Q5_1,
362            Self::Q5_1 => Self::Q5_0,
363            Self::Q5_0 => Self::Q4_1,
364            Self::Q4_1 => Self::Q4_0,
365            Self::Q4_0 => Self::Iq4Nl,
366            Self::Iq4Nl => Self::F32,
367        }
368    }
369    pub fn prev(&self) -> Self {
370        match self {
371            Self::F32 => Self::Iq4Nl,
372            Self::F16 => Self::F32,
373            Self::BF16 => Self::F16,
374            Self::Q8_0 => Self::BF16,
375            Self::Q5_1 => Self::Q8_0,
376            Self::Q5_0 => Self::Q5_1,
377            Self::Q4_1 => Self::Q5_0,
378            Self::Q4_0 => Self::Q4_1,
379            Self::Iq4Nl => Self::Q4_0,
380        }
381    }
382}
383
384impl std::fmt::Display for CacheQuantType {
385    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
386        match self {
387            Self::F32 => write!(f, "f32"),
388            Self::F16 => write!(f, "f16"),
389            Self::BF16 => write!(f, "bf16"),
390            Self::Q8_0 => write!(f, "q8_0"),
391            Self::Q4_0 => write!(f, "q4_0"),
392            Self::Q4_1 => write!(f, "q4_1"),
393            Self::Iq4Nl => write!(f, "iq4_nl"),
394            Self::Q5_0 => write!(f, "q5_0"),
395            Self::Q5_1 => write!(f, "q5_1"),
396        }
397    }
398}
399
400impl From<&str> for CacheQuantType {
401    fn from(s: &str) -> Self {
402        match s {
403            "F32" => Self::F32,
404            "F16" => Self::F16,
405            "BF16" => Self::BF16,
406            "Q8_0" => Self::Q8_0,
407            "Q4_0" => Self::Q4_0,
408            "Q4_1" => Self::Q4_1,
409            "Iq4Nl" => Self::Iq4Nl,
410            "Q5_0" => Self::Q5_0,
411            "Q5_1" => Self::Q5_1,
412            _ => Self::F16, // Default or error handling
413        }
414    }
415}
416
417/// Split mode for multi-GPU.
418#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Hash, Default)]
419pub enum SplitMode {
420    #[serde(rename = "none")]
421    None,
422    #[serde(rename = "layer")]
423    #[default]
424    Layer,
425    #[serde(rename = "row")]
426    Row,
427    #[serde(rename = "tensor")]
428    Tensor,
429}
430
431impl std::fmt::Display for SplitMode {
432    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
433        match self {
434            SplitMode::None => write!(f, "none"),
435            SplitMode::Layer => write!(f, "layer"),
436            SplitMode::Row => write!(f, "row"),
437            SplitMode::Tensor => write!(f, "tensor"),
438        }
439    }
440}
441
442/// NUMA optimization mode.
443#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Hash, Default)]
444pub enum NumMode {
445    #[serde(rename = "none")]
446    #[default]
447    None,
448    #[serde(rename = "distribute")]
449    Distribute,
450    #[serde(rename = "isolate")]
451    Isolate,
452    #[serde(rename = "numactl")]
453    Numactl,
454}
455
456impl std::fmt::Display for NumMode {
457    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
458        match self {
459            NumMode::None => write!(f, "none"),
460            NumMode::Distribute => write!(f, "distribute"),
461            NumMode::Isolate => write!(f, "isolate"),
462            NumMode::Numactl => write!(f, "numactl"),
463        }
464    }
465}
466
/// RoPE frequency scaling method.
///
/// `None` is the `Default`. Each variant's serde name is the same lowercase
/// string the `Display` impl prints.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Hash, Default)]
pub enum RopeScaling {
    /// Serialized as "none" (default).
    #[serde(rename = "none")]
    #[default]
    None,
    /// Serialized as "linear".
    #[serde(rename = "linear")]
    Linear,
    /// Serialized as "yarn".
    #[serde(rename = "yarn")]
    Yarn,
}
478
479impl std::fmt::Display for RopeScaling {
480    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
481        match self {
482            RopeScaling::None => write!(f, "none"),
483            RopeScaling::Linear => write!(f, "linear"),
484            RopeScaling::Yarn => write!(f, "yarn"),
485        }
486    }
487}
488
/// Mirostat version.
///
/// `Off` is the `Default`. The variants serialize as the strings "0", "1"
/// and "2". Note that `Display` prints `Off` as "off", not "0".
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Hash, Default)]
pub enum Mirostat {
    /// Serialized as "0" (default); displayed as "off".
    #[serde(rename = "0")]
    #[default]
    Off,
    /// Serialized and displayed as "1".
    #[serde(rename = "1")]
    V1,
    /// Serialized and displayed as "2".
    #[serde(rename = "2")]
    Mirostat2,
}
500
501impl std::fmt::Display for Mirostat {
502    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
503        match self {
504            Mirostat::Off => write!(f, "off"),
505            Mirostat::V1 => write!(f, "1"),
506            Mirostat::Mirostat2 => write!(f, "2"),
507        }
508    }
509}
510
/// Sampler order string (semicolon-separated).
/// Common types: penalties, dry, top_n_sigma, top_k, typ_p, top_p, min_p, xtc, temperature
///
/// A thin newtype over `String`: the text is stored and printed as-is, and
/// nothing in this type validates the sampler names or the separator.
/// The `Default` value lists all of the common types above, in that order.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Samplers(pub String);
515
516impl Default for Samplers {
517    fn default() -> Self {
518        Self("penalties;dry;top_n_sigma;top_k;typ_p;top_p;min_p;xtc;temperature".to_string())
519    }
520}
521
522impl std::fmt::Display for Samplers {
523    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
524        write!(f, "{}", self.0)
525    }
526}
527
/// Backend used to run the llama.cpp server.
///
/// `Cpu` is the `Default`. Variants are grouped by operating system the same
/// way `Backend::is_linux`, `Backend::is_windows` and `Backend::is_macos`
/// classify them. The serde names use underscores; `Backend::slug` gives a
/// separate hyphenated identifier for each variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum Backend {
    // Linux backends
    /// Serialized as "cpu" (default).
    #[serde(rename = "cpu")]
    #[default]
    Cpu,
    /// Serialized as "vulkan".
    #[serde(rename = "vulkan")]
    Vulkan,
    /// Serialized as "rocm".
    #[serde(rename = "rocm")]
    Rocm,
    /// Serialized as "rocm_lemonade".
    #[serde(rename = "rocm_lemonade")]
    RocmLemonade,
    /// Serialized as "cuda".
    #[serde(rename = "cuda")]
    Cuda,
    /// Serialized as "cpu_arm64".
    #[serde(rename = "cpu_arm64")]
    CpuArm64,
    // Windows backends
    /// Serialized as "win_cpu".
    #[serde(rename = "win_cpu")]
    CpuWindows,
    /// Serialized as "win_vulkan".
    #[serde(rename = "win_vulkan")]
    VulkanWindows,
    /// Serialized as "win_cuda_12_4".
    #[serde(rename = "win_cuda_12_4")]
    CudaWindows12_4,
    /// Serialized as "win_cuda_13_1".
    #[serde(rename = "win_cuda_13_1")]
    CudaWindows13_1,
    /// Serialized as "win_hip".
    #[serde(rename = "win_hip")]
    HipWindows,
    // macOS backends
    /// Serialized as "macos_arm64".
    #[serde(rename = "macos_arm64")]
    CpuMacosArm64,
    /// Serialized as "macos_x64".
    #[serde(rename = "macos_x64")]
    CpuMacosX64,
}
559
560impl Backend {
561    /// Get the identifier used for directory names and asset prefixes.
562    pub fn slug(&self) -> &'static str {
563        match self {
564            Backend::Cpu => "cpu",
565            Backend::Vulkan => "vulkan",
566            Backend::Rocm => "rocm",
567            Backend::RocmLemonade => "rocm-lemonade",
568            Backend::Cuda => "cuda",
569            Backend::CpuArm64 => "cpu-arm64",
570            Backend::CpuWindows => "win-cpu",
571            Backend::VulkanWindows => "win-vulkan",
572            Backend::CudaWindows12_4 => "win-cuda-12.4",
573            Backend::CudaWindows13_1 => "win-cuda-13.1",
574            Backend::HipWindows => "win-hip",
575            Backend::CpuMacosArm64 => "macos-arm64",
576            Backend::CpuMacosX64 => "macos-x64",
577        }
578    }
579
580    /// Returns true if this backend is for Linux.
581    pub fn is_linux(self) -> bool {
582        matches!(
583            self,
584            Backend::Cpu
585                | Backend::Vulkan
586                | Backend::Rocm
587                | Backend::RocmLemonade
588                | Backend::Cuda
589                | Backend::CpuArm64
590        )
591    }
592
593    /// Returns true if this backend is for Windows.
594    pub fn is_windows(self) -> bool {
595        matches!(
596            self,
597            Backend::CpuWindows
598                | Backend::VulkanWindows
599                | Backend::CudaWindows12_4
600                | Backend::CudaWindows13_1
601                | Backend::HipWindows
602        )
603    }
604
605    /// Returns true if this backend is for macOS.
606    pub fn is_macos(self) -> bool {
607        matches!(self, Backend::CpuMacosArm64 | Backend::CpuMacosX64)
608    }
609
610    /// Parse backend from string representation.
611    pub fn from_str(s: &str) -> Self {
612        let s = s.to_lowercase();
613        if s.starts_with("vulkan") || s.starts_with("vk") {
614            Backend::Vulkan
615        } else if s.starts_with("rocm") || s.starts_with("ro") {
616            if s.contains("lemonade") {
617                Backend::RocmLemonade
618            } else {
619                Backend::Rocm
620            }
621        } else if s.starts_with("cuda") || s.starts_with("cu") {
622            Backend::Cuda
623        } else {
624            Backend::Cpu // Default
625        }
626    }
627}
628
629impl std::fmt::Display for Backend {
630    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
631        write!(f, "{}", self.slug())
632    }
633}
634
/// Server mode: normal (single model), router (multiple models), or one of
/// the two benchmark modes.
///
/// `Normal` is the `Default`. `Display` prints separate UI labels that differ
/// from the serde names listed on each variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum ServerMode {
    /// Serialized as "normal" (default); displayed as "Normal".
    #[serde(rename = "normal")]
    #[default]
    Normal,
    /// Serialized as "router"; displayed as "Router (XP!)".
    #[serde(rename = "router")]
    Router,
    /// Serialized as "bench_gpu"; "bench" is also accepted when deserializing.
    /// Displayed as "Bench GPU".
    #[serde(rename = "bench_gpu", alias = "bench")]
    Bench,
    /// Serialized as "bench_tune"; displayed as "BenchTune".
    #[serde(rename = "bench_tune")]
    BenchTune,
}
648
649impl std::fmt::Display for ServerMode {
650    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
651        match self {
652            ServerMode::Normal => write!(f, "Normal"),
653            ServerMode::Router => write!(f, "Router (XP!)"),
654            ServerMode::Bench => write!(f, "Bench GPU"),
655            ServerMode::BenchTune => write!(f, "BenchTune"),
656        }
657    }
658}
659
660// ── ModelSettings ─────────────────────────────────────────────
661//
662// WHEN ADDING A NEW PARAMETER to ModelSettings, update ALL of these locations:
663//   1.  src/models.rs:668       — ModelSettings struct field + doc comment
664//   2.  src/models.rs:179       — From<DefaultParams> for ModelSettings (field mapping)
665//   3.  src/config.rs:564       — DefaultParams struct field + serde attribute
666//   4.  src/config.rs:785       — DefaultParams Default impl (default value)
667//   5.  src/config.rs:175       — ModelOverride struct field (Option<T>)
668//   6.  src/config.rs:299       — ModelOverride::from_settings() (Some(s.field))
669//   7.  src/config.rs:385       — ModelOverride::apply() (one of the 3 macro calls)
670//   8.  src/tui/settings.rs:312 — all_fields() SettingField entry (id, name, section, etc.)
//   9.  src/tui/settings.rs:1149 — profile_settings_parts() diff macro call
672//  10.  src/tui/app/profiles.rs:80 — settings_fingerprint() hash call
673//  11.  src/tui/event/helpers.rs:125 — sync_global_settings() (if global-scoped)
674//  12.  src/tui/event/panel/settings.rs — key handlers for numeric/toggle edit
675//
676// The derived PartialEq on ModelSettings and DefaultParams guarantees
677// is_dirty() compares ALL fields — no field can be missed.
678// The build.rs script checks field counts at compile time.
679
/// Settings for loading a model via llama.cpp server.
///
/// Values are normally produced from `crate::config::DefaultParams` through
/// the `From` impl in this file; `Default` starts from that conversion and
/// then overrides a handful of fields. Equality is the derived field-by-field
/// `PartialEq`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ModelSettings {
    // ── Loading ──────────────────────────────────────────────
    /// Size of the prompt context.
    pub context_length: u32,
    /// Number of CPU threads for generation.
    pub threads: u32,
    /// Number of CPU threads for batch processing.
    pub threads_batch: u32,
    /// Logical maximum batch size.
    pub batch_size: u32,
    /// Physical maximum batch size (micro-batch).
    pub ubatch_size: u32,
    /// Max concurrent predictions (sequences).
    pub parallel: u32,
    /// Max concurrent predictions (requests in flight). None means no --parallel argument.
    /// NOTE(review): how this relates to `parallel` above is not visible in this file — confirm.
    pub max_concurrent_predictions: Option<u32>,
    /// Use uniform (unified) KV cache across all sequences.
    pub uniform_cache: bool,
    /// Offload KV cache to system RAM.
    pub kv_cache_offload: bool,
    /// KV cache data type for K.
    pub cache_type_k: Option<CacheTypeK>,
    /// KV cache data type for V.
    pub cache_type_v: Option<CacheTypeV>,
    /// Keep N tokens from the initial prompt.
    pub keep: i32,
    /// Use full-size SWA cache.
    pub swa_full: bool,
    /// Force system to keep model in RAM.
    pub mlock: bool,
    /// Memory-map the model.
    pub mmap: bool,
    /// NUMA optimization.
    pub numa: NumMode,
    /// System prompt.
    pub system_prompt: String,
    /// Name of the system prompt preset currently selected.
    pub system_prompt_preset_name: String,

    // ── GPU ──────────────────────────────────────────────────
    /// GPU layer offloading mode.
    pub gpu_layers_mode: GpuLayersMode,
    /// Split mode across multiple GPUs.
    pub split_mode: SplitMode,
    /// Fraction of model offloaded to each GPU (comma-separated).
    pub tensor_split: String,
    /// Main GPU index.
    pub main_gpu: i32,
    /// Whether to adjust arguments to fit device memory.
    pub fit: bool,
    /// Path to LoRA adapter.
    pub lora: Option<PathBuf>,
    /// Path to LoRA adapter with scale.
    pub lora_scaled: Option<(PathBuf, f32)>,
    /// RPC servers.
    pub rpc: String,
    /// Restrict to embedding use case.
    pub embedding: bool,
    /// Enable Flash Attention.
    pub flash_attn: bool,
    /// Active experts per token (MoE models, -1 = model default).
    pub expert_count: i32,
    /// Use Jinja template engine for chat.
    pub jinja: bool,
    /// Custom chat template string.
    pub chat_template: Option<String>,
    /// JSON string for --chat-template-kwargs (e.g. {"enable_thinking": false}).
    pub chat_template_kwargs: Option<String>,

    // ── Sampling ─────────────────────────────────────────────
    /// RNG seed (-1 = random).
    pub seed: i32,
    /// Temperature.
    pub temperature: f32,
    /// Top-k sampling (0 = disabled).
    pub top_k: i32,
    /// Top-p sampling (1.0 = disabled).
    pub top_p: f32,
    /// Minimum probability for a token.
    pub min_p: f32,
    /// Locally typical sampling parameter p.
    pub typical_p: f32,
    /// Mirostat version (0=off, 1=Mirostat, 2=Mirostat2).
    pub mirostat: Mirostat,
    /// Mirostat learning rate (eta).
    pub mirostat_lr: f32,
    /// Mirostat target entropy (tau).
    pub mirostat_ent: f32,
    /// Ignore end-of-stream token.
    pub ignore_eos: bool,
    /// Sampler order string.
    pub samplers: Samplers,

    // ── Repetition Control ───────────────────────────────────
    /// Penalize repeat sequence of tokens.
    pub repeat_penalty: f32,
    /// Last N tokens to consider for repeat penalty.
    pub repeat_last_n: i32,
    /// Repeat alpha presence penalty. `Default` sets this to `Some(0.0)`.
    pub presence_penalty: Option<f32>,
    /// Repeat alpha frequency penalty. `Default` sets this to `Some(0.0)`.
    pub frequency_penalty: Option<f32>,
    /// DRY sampling multiplier.
    pub dry_multiplier: f32,
    /// DRY sampling base value.
    pub dry_base: f32,
    /// DRY allowed length.
    pub dry_allowed_length: i32,
    /// DRY penalty last N.
    pub dry_penalty_last_n: i32,

    // ── RoPE ─────────────────────────────────────────────────
    /// RoPE frequency scaling method.
    pub rope_scaling: RopeScaling,
    /// RoPE context scaling factor.
    pub rope_scale: f32,
    /// RoPE base frequency.
    pub rope_freq_base: f32,
    /// RoPE frequency scaling factor.
    pub rope_freq_scale: f32,
    /// Enable Yarn RoPE scaling mode.
    pub rope_yarn_enabled: bool,

    // ── Server ───────────────────────────────────────────────
    /// Host address.
    pub host: String,
    /// Port.
    pub port: u16,
    /// Server timeout in seconds.
    pub timeout: u32,
    /// Whether to enable prompt caching.
    pub cache_prompt: bool,
    /// Min chunk size for cache reuse.
    pub cache_reuse: u32,
    /// Whether to enable WebUI.
    pub webui: bool,

    // ── Other ────────────────────────────────────────────────
    /// Max tokens to predict.
    pub max_tokens: Option<u32>,
    /// Cache type (legacy, kept for compatibility).
    pub cache_type: CacheType,
    /// Backend used to run the llama.cpp server (any `Backend` variant).
    /// Selects which of the `llama_cpp_version_*` fields below is the active one.
    pub backend: Backend,
    /// llama.cpp release tag for CPU backend (e.g. "b1234" or None for latest).
    pub llama_cpp_version_cpu: Option<String>,
    /// llama.cpp release tag for Vulkan backend (e.g. "b1234" or None for latest).
    pub llama_cpp_version_vulkan: Option<String>,
    /// llama.cpp release tag for ROCm backend (e.g. "b1234" or None for latest).
    pub llama_cpp_version_rocm: Option<String>,
    /// Lemonade llama.cpp release tag for the ROCm (Lemonade) backend (None is shown as "latest").
    pub llama_cpp_version_rocm_lemonade: Option<String>,
    /// llama.cpp release tag for CUDA backend (None is shown as "latest").
    pub llama_cpp_version_cuda: Option<String>,
    /// Whether to enable the API proxy server.
    pub api_endpoint_enabled: bool,
    /// Port for the API proxy server.
    pub api_endpoint_port: u16,
    /// Speculative decoding type (e.g., "draft-mtp", "ngram-simple", "" for off).
    pub spec_type: String,
    /// Number of draft tokens for MTP.
    pub draft_tokens: u32,
    /// Tags for the model.
    pub tags: Vec<String>,
}
847
848impl Default for ModelSettings {
849    fn default() -> Self {
850        let mut s: Self = crate::config::DefaultParams::default().into();
851        // Override fields that differ from DefaultParams defaults
852        s.uniform_cache = false;
853        s.cache_type_k = Some(CacheTypeK::F16);
854        s.cache_type_v = Some(CacheTypeV::F16);
855        s.cache_type = CacheType::default();
856        s.backend = Backend::Cpu;
857        s.presence_penalty = Some(0.0);
858        s.frequency_penalty = Some(0.0);
859        s
860    }
861}
862
863impl ModelSettings {
864    /// Create ModelSettings from config defaults, applying model-specific overrides.
865    pub fn from_config(config: &crate::config::Config) -> Self {
866        config.default.clone().into()
867    }
868}
869
/// Default benchmark prompt used when starting a tuning session.
///
/// The text is a runtime value: any edit to its wording, including
/// capitalization or punctuation, changes the prompt that users of this
/// constant receive.
pub const BENCHMARK_PROMPT: &str = "Create Mona Lisa image in ascii art using text, number, symbol, everything possible. this should be the perfect painting.";
872
/// A discovered model file.
#[derive(Debug, Clone)]
pub struct DiscoveredModel {
    /// Path of the model file.
    pub path: PathBuf,
    /// Name of the model.
    pub name: String,
    /// File size. NOTE(review): presumably in bytes — confirm with the scanner that fills it.
    pub file_size: u64,
    /// Path relative to model_dir, for display.
    pub display_name: String, // path relative to model_dir for display
}
881
/// Parsed GGUF metadata for a model, cached to avoid re-parsing the file.
///
/// `Default` yields zeros and empty strings. For the numeric fields filled
/// from `<arch>.*` keys by `GgufMetadata::from_path`, the lookup falls back to
/// the matching `llama.*` key, and a value of 0 means the key was absent.
#[derive(Debug, Clone, Default)]
pub struct GgufMetadata {
    /// From `<arch>.block_count`.
    pub layers: u32,
    /// From `<arch>.embedding_length`.
    pub hidden_size: u32,
    /// From `<arch>.context_length`.
    pub n_ctx_train: u32,
    /// From `<arch>.attention.head_count`.
    pub n_head: u32,
    /// From `<arch>.attention.head_count_kv`.
    pub n_kv_head: u32,
    /// From `general.architecture`; empty when the key is missing
    /// (in which case `llama` is used as the key prefix for the fields above).
    pub arch: String,
    /// File type string.
    pub file_type: String,
    /// Quantization label.
    pub quantization: String,
    /// Parameter count as a display string.
    pub model_parameters: String,
    /// Domain label.
    pub domain: String,
    /// Capability labels.
    pub capabilities: Vec<String>,
    /// Tokenizer name.
    pub tokenizer: String,
    /// Vocabulary size.
    pub vocab_size: u32,
    /// Number of draft tokens.
    /// NOTE(review): source key not visible in this part of the file — confirm in `from_path`.
    pub draft_tokens: u32,
}
900
901impl GgufMetadata {
902    pub fn from_path(path: &std::path::Path) -> anyhow::Result<Self> {
903        let path_str = path.to_string_lossy();
904        let mut container = gguf_rs::get_gguf_container(&path_str)
905            .map_err(|e| anyhow::anyhow!("Failed to get GGUF container: {}", e))?;
906        let model_data = container
907            .decode()
908            .map_err(|e| anyhow::anyhow!("Failed to decode GGUF: {}", e))?;
909
910        let mut meta = Self::default();
911
912        let extract_str = |key: &str| -> String {
913            model_data
914                .metadata()
915                .get(key)
916                .and_then(|v| v.as_str().map(|s| s.to_string()))
917                .unwrap_or_default()
918        };
919
920        let extract_num = |key: &str| -> Option<u64> {
921            model_data.metadata().get(key).and_then(|v| {
922                v.as_u64()
923                    .or_else(|| v.as_i64().map(|x| x as u64))
924                    .or_else(|| v.as_f64().map(|x| x as u64))
925            })
926        };
927
928        meta.arch = extract_str("general.architecture");
929        let prefix = if meta.arch.is_empty() {
930            "llama"
931        } else {
932            &meta.arch
933        };
934
935        let get_num_with_fallback = |suffix: &str| -> u32 {
936            extract_num(&format!("{}.{}", prefix, suffix))
937                .or_else(|| {
938                    if prefix != "llama" {
939                        extract_num(&format!("llama.{}", suffix))
940                    } else {
941                        None
942                    }
943                })
944                .unwrap_or(0) as u32
945        };
946
947        meta.layers = get_num_with_fallback("block_count");
948        meta.hidden_size = get_num_with_fallback("embedding_length");
949        meta.n_ctx_train = get_num_with_fallback("context_length");
950        meta.n_head = get_num_with_fallback("attention.head_count");
951        meta.n_kv_head = get_num_with_fallback("attention.head_count_kv");
952
953        if let Some(value) = model_data.metadata().get("tokenizer.ggml.tokens")
954            && let Some(arr) = value.as_array()
955        {
956            meta.vocab_size = arr.len() as u32;
957        }
958
959        if meta.arch == "mtp" {
960            meta.draft_tokens = extract_num("mtp.draft_tokens").unwrap_or(0) as u32;
961        }
962
963        if let Some(v) = extract_num("general.file_type") {
964            meta.file_type = match v {
965                0 => "F32".to_string(),
966                1 => "F16".to_string(),
967                2 => "Q4_0".to_string(),
968                3 => "Q4_1".to_string(),
969                7 => "Q8_0".to_string(),
970                8 => "Q5_0".to_string(),
971                9 => "Q5_1".to_string(),
972                10 => "Q2_K".to_string(),
973                11 => "Q3_K_S".to_string(),
974                12 => "Q3_K_M".to_string(),
975                13 => "Q3_K_L".to_string(),
976                14 => "Q4_K_S".to_string(),
977                15 => "Q4_K_M".to_string(),
978                16 => "Q5_K_S".to_string(),
979                17 => "Q5_K_M".to_string(),
980                18 => "Q6_K".to_string(),
981                19 => "IQ2_XXS".to_string(),
982                20 => "IQ2_XS".to_string(),
983                21 => "IQ3_XXS".to_string(),
984                22 => "IQ1_S".to_string(),
985                23 => "IQ4_NL".to_string(),
986                24 => "IQ3_S".to_string(),
987                25 => "IQ2_S".to_string(),
988                26 => "IQ4_XS".to_string(),
989                _ => format!("Unknown ({})", v),
990            };
991        }
992
993        if let Some(value) = model_data.metadata().get("general.capabilities")
994            && let Some(arr) = value.as_array()
995        {
996            for v in arr {
997                if let Some(s) = v.as_str() {
998                    meta.capabilities.push(s.to_string());
999                }
1000            }
1001        }
1002
1003        if model_data
1004            .metadata()
1005            .contains_key("tokenizer.chat_template")
1006        {
1007            meta.capabilities.push("chat".to_string());
1008        }
1009
1010        meta.tokenizer = extract_str("tokenizer.ggml.model");
1011        meta.domain = extract_str("general.domain");
1012        meta.model_parameters = model_data.model_parameters();
1013
1014        Ok(meta)
1015    }
1016}
1017
/// Metrics reported by the llama.cpp server.
///
/// NOTE(review): units of the memory fields and the scale of `cpu_usage` are not
/// established in this part of the file; confirm against the code that fills them.
#[derive(Debug, Clone)]
pub struct ServerMetrics {
    /// Whether the model is reported as loaded.
    pub loaded: bool,
    /// Tokens per second. NOTE(review): how this differs from `gen_tps` is not visible here.
    pub tps: f64,
    /// Prompt-processing tokens per second (the input to `prompt_latency_ms`).
    pub prompt_tps: f64,
    /// CPU usage figure.
    pub cpu_usage: f64,
    /// GPU memory currently in use.
    pub gpu_mem_used: u64,
    /// Total GPU memory.
    pub gpu_mem_total: u64,
    /// System RAM in use.
    pub ram_used: u64,
    /// Context tokens currently in use.
    pub ctx_used: u32,
    /// Maximum context size.
    pub ctx_max: u32,
    /// Sum of gpu_mem_used across all loaded models (for Total VRAM display).
    pub total_vram_used: u64,
    /// Number of decoded tokens from print_timing logs.
    pub decoded_tokens: u64,
    /// Generation tokens per second parsed from llama.cpp log output (e.g., "tg = 64.45 t/s").
    pub gen_tps: f64,
    /// Estimated latency per generated token in milliseconds.
    pub latency_per_token_ms: f64,
    /// Estimated prompt processing latency in milliseconds (1000 / prompt_tps).
    pub prompt_latency_ms: f64,
}
1041
/// GPU device buffer reported by llama-server during model loading.
#[derive(Debug, Clone)]
pub struct GPUBuffer {
    /// Device identifier as reported. NOTE(review): exact format not visible here.
    pub device: String,
    /// Buffer size in MiB (per the field name).
    pub buffer_size_mib: f64,
}
1048
/// Progress information during model loading, parsed from llama-server log output.
///
/// The derived `Default` starts with every optional count unknown (`None`),
/// zero tensors loaded and no buffers.
#[derive(Debug, Clone, Default)]
pub struct LoadProgress {
    /// Total number of layers in the model.
    pub layers_total: Option<u32>,
    /// Number of layers already offloaded to GPU.
    pub layers_loaded: Option<u32>,
    /// Total number of tensors in the model (from "Loading tensor X of Y" log).
    pub tensors_total: Option<u32>,
    /// Number of tensors loaded (counted from dot-lines in log).
    pub tensors_loaded: u32,
    /// GPU device buffers with their sizes.
    pub buffers: Vec<GPUBuffer>,
}
1063
1064impl Default for ServerMetrics {
1065    fn default() -> Self {
1066        Self {
1067            loaded: false,
1068            tps: 0.0,
1069            prompt_tps: 0.0,
1070            cpu_usage: 0.0,
1071            gpu_mem_used: 0,
1072            gpu_mem_total: 0,
1073            ram_used: 0,
1074            ctx_used: 0,
1075            ctx_max: 0,
1076            total_vram_used: 0,
1077            decoded_tokens: 0,
1078            gen_tps: 0.0,
1079            latency_per_token_ms: 0.0,
1080            prompt_latency_ms: 0.0,
1081        }
1082    }
1083}
1084
/// WebSocket-friendly metrics snapshot (serializable, no internal state).
///
/// Built by `WsMetrics::from_metrics`: the first group of fields is copied from
/// `ServerMetrics`, the settings group from `ModelSettings`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct WsMetrics {
    /// Name of the model this snapshot describes.
    pub model_name: String,
    /// Copied from `ServerMetrics::loaded`.
    pub loaded: bool,
    /// Textual state supplied by the caller of `from_metrics`.
    pub state: String,
    pub tps: f64,
    pub prompt_tps: f64,
    pub ctx_used: u32,
    pub ctx_max: u32,
    pub cpu_usage: f64,
    pub gpu_mem_used: u64,
    pub gpu_mem_total: u64,
    pub ram_used: u64,
    pub latency_per_token_ms: f64,
    pub decoded_tokens: u64,
    pub gen_tps: f64,
    /// Seconds since the Unix epoch when the snapshot was built
    /// (0 if the system clock reads earlier than the epoch).
    pub timestamp: u64,
    // Server command
    /// Display form of the server command, when the caller provides one.
    pub cmd_display: Option<String>,
    // LLM settings
    pub threads: u32,
    pub threads_batch: u32,
    pub context_length: u32,
    pub ubatch_size: u32,
    pub batch_size: u32,
    pub temperature: f32,
    pub top_k: u32,
    pub top_p: f32,
    pub min_p: f32,
    pub typical_p: f32,
    pub seed: i32,
    pub repeat_penalty: f32,
    pub repeat_last_n: i32,
    pub presence_penalty: Option<f32>,
    pub frequency_penalty: Option<f32>,
    /// Mirostat mode as a number: 0 = off, 1 = V1, 2 = Mirostat2.
    /// `from_metrics` always fills this with `Some`.
    pub mirostat: Option<u32>,
    /// `from_metrics` always fills this with `Some`.
    pub mirostat_lr: Option<f32>,
    /// `from_metrics` always fills this with `Some`.
    pub mirostat_ent: Option<f32>,
    pub max_tokens: Option<u32>,
    pub flash_attn: bool,
    pub kv_cache_offload: bool,
    /// String form of the K cache type, if one is set.
    pub cache_type_k: Option<String>,
    /// String form of the V cache type, if one is set.
    pub cache_type_v: Option<String>,
    pub uniform_cache: bool,
    pub mlock: bool,
    pub mmap: bool,
    pub embedding: bool,
    pub jinja: bool,
    pub ignore_eos: bool,
    /// String form of the configured samplers.
    pub samplers: String,
    pub expert_count: i32,
    /// `"Auto"`, `"All"`, or the specific layer count rendered as a number.
    pub gpu_layers: String,
    /// String form of the selected backend.
    pub backend: String,
    /// Display string for the active backend version.
    pub llama_cpp_version: String,
    pub spec_type: String,
    pub draft_tokens: u32,
}
1143
1144impl WsMetrics {
1145    pub fn from_metrics(
1146        metrics: &ServerMetrics,
1147        model_name: &str,
1148        state: &str,
1149        settings: &crate::models::ModelSettings,
1150        cmd_display: Option<&str>,
1151    ) -> Self {
1152        use std::time::{SystemTime, UNIX_EPOCH};
1153        let timestamp = SystemTime::now()
1154            .duration_since(UNIX_EPOCH)
1155            .map(|d| d.as_secs())
1156            .unwrap_or(0);
1157        let gpu_layers = match settings.gpu_layers_mode {
1158            crate::models::GpuLayersMode::Auto => "Auto".to_string(),
1159            crate::models::GpuLayersMode::Specific(n) => n.to_string(),
1160            crate::models::GpuLayersMode::All => "All".to_string(),
1161        };
1162        Self {
1163            model_name: model_name.to_string(),
1164            loaded: metrics.loaded,
1165            state: state.to_string(),
1166            tps: metrics.tps,
1167            prompt_tps: metrics.prompt_tps,
1168            ctx_used: metrics.ctx_used,
1169            ctx_max: metrics.ctx_max,
1170            cpu_usage: metrics.cpu_usage,
1171            gpu_mem_used: metrics.gpu_mem_used,
1172            gpu_mem_total: metrics.gpu_mem_total,
1173            ram_used: metrics.ram_used,
1174            latency_per_token_ms: metrics.latency_per_token_ms,
1175            decoded_tokens: metrics.decoded_tokens,
1176            gen_tps: metrics.gen_tps,
1177            timestamp,
1178            cmd_display: cmd_display.map(String::from),
1179            threads: settings.threads,
1180            threads_batch: settings.threads_batch,
1181            context_length: settings.context_length,
1182            ubatch_size: settings.ubatch_size,
1183            batch_size: settings.batch_size,
1184            temperature: settings.temperature,
1185            top_k: settings.top_k as u32,
1186            top_p: settings.top_p,
1187            min_p: settings.min_p,
1188            typical_p: settings.typical_p,
1189            seed: settings.seed,
1190            repeat_penalty: settings.repeat_penalty,
1191            repeat_last_n: settings.repeat_last_n,
1192            presence_penalty: settings.presence_penalty,
1193            frequency_penalty: settings.frequency_penalty,
1194            mirostat: Some(match settings.mirostat {
1195                crate::models::Mirostat::Off => 0,
1196                crate::models::Mirostat::V1 => 1,
1197                crate::models::Mirostat::Mirostat2 => 2,
1198            }),
1199            mirostat_lr: Some(settings.mirostat_lr),
1200            mirostat_ent: Some(settings.mirostat_ent),
1201            max_tokens: settings.max_tokens,
1202            flash_attn: settings.flash_attn,
1203            kv_cache_offload: settings.kv_cache_offload,
1204            cache_type_k: settings.cache_type_k.map(|k| k.to_string()),
1205            cache_type_v: settings.cache_type_v.map(|k| k.to_string()),
1206            uniform_cache: settings.uniform_cache,
1207            mlock: settings.mlock,
1208            mmap: settings.mmap,
1209            embedding: settings.embedding,
1210            jinja: settings.jinja,
1211            ignore_eos: settings.ignore_eos,
1212            samplers: settings.samplers.to_string(),
1213            expert_count: settings.expert_count,
1214            gpu_layers,
1215            backend: settings.backend.to_string(),
1216            llama_cpp_version: settings.get_active_backend_version_display().to_string(),
1217            spec_type: settings.spec_type.clone(),
1218            draft_tokens: settings.draft_tokens,
1219        }
1220    }
1221}
1222
1223/// Estimate VRAM usage (in MiB) for a model with the given settings.
1224///
1225/// Model file size is the size of the GGUF file in MiB. The model itself
1226/// takes 1x its size in VRAM (loaded as-is). KV cache is the dominant
1227/// variable cost — it scales with context_length, batch_size, and layers.
1228///
1229/// The KV cache formula accounts for:
1230/// - Actual GQA ratio from model metadata (n_kv_head / n_head)
1231/// - FlashAttention: reduces KV cache storage by ~2x
1232/// - Unified KV cache: shares KV across sequences, dividing by parallel count
1233/// - KV cache quantization (q4_0, q5_0, q8_0, etc.)
1234pub fn estimate_vram_mib(
1235    model_mib: u64,
1236    settings: &ModelSettings,
1237    total_layers: u32,
1238    hidden_size_opt: Option<u32>,
1239    n_head_opt: Option<u32>,
1240    n_kv_head_opt: Option<u32>,
1241    gpu_mem_total_mib: u64,
1242) -> u64 {
1243    let model_mib_f = model_mib as f64;
1244
1245    // Compute how much of the model is loaded into VRAM based on GPU layers.
1246    let gpu_layers = match settings.gpu_layers_mode {
1247        GpuLayersMode::Auto => {
1248            // Heuristic: ~60% of layers when Auto (llama.cpp will decide at runtime)
1249            if total_layers > 0 {
1250                (total_layers as f64 * 0.6) as u32
1251            } else {
1252                20
1253            }
1254        }
1255        GpuLayersMode::Specific(n) => {
1256            if total_layers > 0 {
1257                n.min(total_layers)
1258            } else {
1259                n
1260            }
1261        }
1262        GpuLayersMode::All => {
1263            if total_layers > 0 {
1264                total_layers
1265            } else {
1266                32
1267            }
1268        }
1269    };
1270
1271    // Model weights loaded into VRAM: proportional to GPU layers.
1272    let model_vram = if total_layers > 0 && gpu_layers > 0 {
1273        model_mib_f * (gpu_layers as f64 / total_layers as f64).min(1.0)
1274    } else if gpu_layers > 0 {
1275        model_mib_f
1276    } else {
1277        0.0
1278    };
1279
1280    if matches!(settings.gpu_layers_mode, GpuLayersMode::Specific(0)) {
1281        return 0; // CPU only
1282    }
1283
1284    // Heuristic for hidden_size if not provided:
1285    // A 7B model (4-bit) is ~4000 MiB and has hidden=4096.
1286    let hidden_size = match hidden_size_opt {
1287        Some(h) => h as f64,
1288        None => {
1289            let params_est = model_mib_f / 550.0;
1290            (1024.0 * params_est.sqrt().max(1.0) * 1.5).max(512.0)
1291        }
1292    };
1293
1294    // ── KV cache estimation ─────────────────────────────────────
1295
1296    // GQA ratio: real KV heads vs query heads.
1297    // If n_kv_head == n_head, ratio = 1.0 (no reduction).
1298    // If n_kv_head < n_head, ratio < 1.0 (KV cache is smaller).
1299    let gqa_ratio = match (n_head_opt, n_kv_head_opt) {
1300        (Some(n_head), Some(n_kv_head)) if n_head > 0 => n_kv_head as f64 / n_head as f64,
1301        _ => 1.0, // fallback: assume no GQA
1302    };
1303
1304    // FlashAttention reduces KV cache storage by ~2x because it doesn't
1305    // need to keep the full attention matrix in memory.
1306    let flash_attn_factor = if settings.flash_attn { 0.5 } else { 1.0 };
1307
1308    // Unified KV cache shares a single KV buffer across all sequences.
1309    let uniform_cache_factor = if settings.uniform_cache {
1310        1.0 / settings.parallel as f64
1311    } else {
1312        1.0
1313    };
1314
1315    // Effective context length: YaRN RoPE scale extends the usable context.
1316    let effective_ctx = settings.context_length as f64 * settings.rope_scale as f64;
1317
1318    // KV cache in MiB:
1319    // Formula: 2 * n_layer * n_ctx * n_embd_kv * sizeof(type)
1320    // n_embd_kv = hidden_size * gqa_ratio
1321    // The KV cache is allocated for the total number of model layers,
1322    // not just the number layers loaded into the GPU (gpu_layers).
1323    // However only gpu_layers * sizeof(type) contributes to the VRAM cost.
1324    let kv_mib = (2.0
1325        * hidden_size
1326        * effective_ctx
1327        * total_layers as f64
1328        * gqa_ratio
1329        * gpu_layers as f64
1330        / total_layers as f64  // VRAM cost: only GPU-loaded portion of KV cache
1331        * flash_attn_factor
1332        * uniform_cache_factor
1333        * kv_quant_bytes(
1334            settings.cache_type_k.unwrap_or(CacheTypeK::F16),
1335            settings.cache_type_v.unwrap_or(CacheTypeV::F16)
1336        ))
1337        / (1024.0 * 1024.0);
1338
1339    // Activation overhead during inference (proportional to batch * hidden).
1340    // Increased multiplier to 8.0 (from 2.0) to be more pessimistic about scratch buffers.
1341    let activation_mib = (settings.batch_size as f64 * hidden_size * 8.0) / (1024.0 * 1024.0);
1342
1343    // Fixed overhead for driver, fragmentation, and small meta buffers.
1344    // Use 3.8% of max VRAM, falling back to 500MiB if unknown.
1345    let fixed_overhead = if gpu_mem_total_mib > 0 {
1346        gpu_mem_total_mib as f64 * 0.038
1347    } else {
1348        500.0
1349    };
1350
1351    let total_mib = model_vram + kv_mib + activation_mib + fixed_overhead + 550.0;
1352
1353    total_mib.ceil() as u64
1354}
1355
1356/// Return the average KV cache element size in bytes for the given K/V types.
1357///
1358/// KV cache stores K and V separately, potentially at different precisions.
1359/// We average the two to get a single per-element size.
1360fn kv_quant_bytes(k_type: CacheQuantType, v_type: CacheQuantType) -> f64 {
1361    let get_bytes = |t: CacheQuantType| match t {
1362        CacheQuantType::F32 => 4.0,
1363        CacheQuantType::F16 | CacheQuantType::BF16 => 2.0,
1364        CacheQuantType::Q8_0 => 1.0,
1365        CacheQuantType::Q5_0 | CacheQuantType::Q5_1 => 0.625, // 5 bits
1366        CacheQuantType::Q4_0 | CacheQuantType::Q4_1 | CacheQuantType::Iq4Nl => 0.5, // 4 bits
1367    };
1368    (get_bytes(k_type) + get_bytes(v_type)) / 2.0
1369}
1370
1371impl ModelSettings {
1372    /// Check if this settings differs from `other` in any field.
1373    /// Uses derived PartialEq which compares all fields — compiler-enforced.
1374    pub fn is_dirty(&self, other: &Self) -> bool {
1375        self != other
1376    }
1377}
1378
1379// Benchmark Tuning types
/// Configuration for one benchmark tuning session.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct BenchTuneConfig {
    /// Model file to benchmark.
    pub model_path: PathBuf,
    /// Number of iterations. NOTE(review): per parameter combination, presumably — confirm.
    pub num_iterations: u32,
    /// Prompt text used for the benchmark requests.
    pub prompt: String,
    /// Tunable parameters; only entries with `enabled == true` are varied.
    pub params_to_test: Vec<BenchTuneParam>,
    /// Duration budget for a test (`BenchTuneConfig::new` sets 30 seconds).
    pub test_duration: Duration,
    /// Whether servers are restarted per combination (see `BenchTuneMode`).
    pub bench_mode: BenchTuneMode,
    /// Token count setting (`BenchTuneConfig::new` sets 512).
    pub n_predict: u32,
    /// Optional JSON text with chat-template arguments
    /// (`BenchTuneConfig::new` sets `{"enable_thinking": false}`).
    pub chat_template_kwargs: Option<String>,
    /// Timeout for a test (`BenchTuneConfig::new` sets 60 seconds).
    pub test_timeout: Duration,
}
1392
/// One tunable parameter and the range of values to sweep for it.
///
/// `PartialEq`/`Eq` are implemented by hand because of the `f64` fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchTuneParam {
    /// Parameter name, e.g. `"temperature"` or `"spec_type"`.
    pub name: String,
    /// Lower bound of the sweep.
    pub min: f64,
    /// Upper bound of the sweep.
    pub max: f64,
    /// Increment between successive values.
    pub step: f64,
    /// Whether this parameter takes part in the sweep.
    pub enabled: bool,
    /// Named choices for non-numeric parameters (empty for numeric ones).
    pub variants: Vec<String>,
}
1402
1403impl PartialEq for BenchTuneParam {
1404    fn eq(&self, other: &Self) -> bool {
1405        self.name == other.name
1406            && self.min.to_bits() == other.min.to_bits()
1407            && self.max.to_bits() == other.max.to_bits()
1408            && self.step.to_bits() == other.step.to_bits()
1409            && self.enabled == other.enabled
1410            && self.variants == other.variants
1411    }
1412}
1413impl Eq for BenchTuneParam {}
1414
/// One concrete combination of parameter values to benchmark.
///
/// Every field is optional. NOTE(review): `None` presumably means "not varied in
/// this run" — confirm against the benchmark runner.
/// `PartialEq`/`Eq` are implemented by hand because of the `f64` fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchTuneParamValue {
    pub temperature: Option<f64>,
    pub top_p: Option<f64>,
    pub top_k: Option<i64>,
    pub repeat_penalty: Option<f64>,
    pub context_length: Option<u32>,
    pub batch_size: Option<u32>,
    pub flash_attn: Option<bool>,
    pub threads: Option<u32>,
    pub expert_count: Option<i32>,
    /// Speculative decoding type name, e.g. `"draft-mtp"`.
    pub spec_type: Option<String>,
    pub draft_tokens: Option<u32>,
}
1429
1430impl PartialEq for BenchTuneParamValue {
1431    fn eq(&self, other: &Self) -> bool {
1432        self.temperature.map(|v| v.to_bits()) == other.temperature.map(|v| v.to_bits())
1433            && self.top_p.map(|v| v.to_bits()) == other.top_p.map(|v| v.to_bits())
1434            && self.top_k == other.top_k
1435            && self.repeat_penalty.map(|v| v.to_bits()) == other.repeat_penalty.map(|v| v.to_bits())
1436            && self.context_length == other.context_length
1437            && self.batch_size == other.batch_size
1438            && self.flash_attn == other.flash_attn
1439            && self.threads == other.threads
1440            && self.expert_count == other.expert_count
1441            && self.spec_type == other.spec_type
1442            && self.draft_tokens == other.draft_tokens
1443    }
1444}
1445impl Eq for BenchTuneParamValue {}
1446
/// Outcome of benchmarking one parameter combination.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchTuneResult {
    /// The parameter combination that was tested.
    pub params: BenchTuneParamValue,
    /// Metrics for this combination. NOTE(review): presumably aggregated over the
    /// entries of `per_iteration_metrics` — confirm against the runner.
    pub metrics: BenchTuneMetrics,
    /// Generated outputs collected during the test.
    pub outputs: Vec<String>,
    /// Metrics recorded for each individual iteration.
    pub per_iteration_metrics: Vec<BenchTuneMetrics>,
    /// Settings the test was based on, when recorded.
    pub base_settings: Option<ModelSettings>,
    /// Server command used for the test, when recorded.
    pub server_command: Option<String>,
}
1456
/// Throughput and latency figures measured for a benchmark run.
///
/// NOTE(review): time units (seconds vs milliseconds) are not established in this
/// part of the file; confirm against the code that fills these fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchTuneMetrics {
    /// Prompt-processing tokens per second.
    pub prompt_tps: f64,
    /// Generation tokens per second.
    pub generation_tps: f64,
    /// Combined tokens-per-second figure.
    pub combined_tps: f64,
    /// Latency per token.
    pub latency_per_token: f64,
    /// Time until the first token.
    pub first_token_time: f64,
}
1465
/// Serializable status of a benchmark tuning session.
///
/// `BenchTuneProgress` has the same variants and fields;
/// `BenchTuneProgress::from_status` copies one into the other.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum BenchTuneStatus {
    /// The session is in progress.
    Running {
        /// Position of the current test.
        current: usize,
        /// Total number of tests.
        total: usize,
        /// Progress figure. NOTE(review): scale (0-1 vs 0-100) not visible here.
        progress: f32,
        /// Parameter combination currently under test.
        current_params: BenchTuneParamValue,
    },
    /// The session finished.
    Completed {
        total_tests: usize,
        successful_tests: usize,
        elapsed: Duration,
    },
    /// The session finished, with a count of failed tests reported.
    PartiallyCompleted {
        total_tests: usize,
        successful_tests: usize,
        failed_tests: usize,
        elapsed: Duration,
    },
    /// The session was cancelled before finishing.
    Cancelled {
        total_tests: usize,
        successful_tests: usize,
        failed_tests: usize,
        elapsed: Duration,
    },
    /// The session failed with the given message.
    Error {
        error: String,
    },
}
1495
/// How benchmark tuning applies each parameter combination.
///
/// The default mode is [`BenchTuneMode::Full`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
pub enum BenchTuneMode {
    /// Runtime-only mode: sends all params in /completion request body, no server restarts
    RuntimeOnly,
    /// Full mode: spawns a new server for each parameter combination (tests server-level params)
    #[default]
    Full,
}
1504
/// Progress status for benchmark tuning
///
/// Mirrors `BenchTuneStatus` variant for variant, but without the serde derives.
#[derive(Debug, Clone)]
pub enum BenchTuneProgress {
    /// Tuning is running.
    Running {
        /// Position of the current test.
        current: usize,
        /// Total number of tests.
        total: usize,
        /// Progress figure. NOTE(review): scale (0-1 vs 0-100) not visible here.
        progress: f32,
        /// Parameter combination currently under test.
        current_params: BenchTuneParamValue,
    },
    /// Tuning is complete.
    Completed {
        total_tests: usize,
        successful_tests: usize,
        elapsed: Duration,
    },
    /// Tuning completed with some failures.
    PartiallyCompleted {
        total_tests: usize,
        successful_tests: usize,
        failed_tests: usize,
        elapsed: Duration,
    },
    /// Tuning was cancelled by the user.
    Cancelled {
        total_tests: usize,
        successful_tests: usize,
        failed_tests: usize,
        elapsed: Duration,
    },
    /// Tuning failed.
    Error { error: String },
}
1538
1539impl BenchTuneProgress {
1540    pub fn from_status(status: &BenchTuneStatus) -> Option<Self> {
1541        match status {
1542            BenchTuneStatus::Running {
1543                current,
1544                total,
1545                progress,
1546                current_params,
1547            } => Some(BenchTuneProgress::Running {
1548                current: *current,
1549                total: *total,
1550                progress: *progress,
1551                current_params: current_params.clone(),
1552            }),
1553            BenchTuneStatus::Completed {
1554                total_tests,
1555                successful_tests,
1556                elapsed,
1557            } => Some(BenchTuneProgress::Completed {
1558                total_tests: *total_tests,
1559                successful_tests: *successful_tests,
1560                elapsed: *elapsed,
1561            }),
1562            BenchTuneStatus::PartiallyCompleted {
1563                total_tests,
1564                successful_tests,
1565                failed_tests,
1566                elapsed,
1567            } => Some(BenchTuneProgress::PartiallyCompleted {
1568                total_tests: *total_tests,
1569                successful_tests: *successful_tests,
1570                failed_tests: *failed_tests,
1571                elapsed: *elapsed,
1572            }),
1573            BenchTuneStatus::Cancelled {
1574                total_tests,
1575                successful_tests,
1576                failed_tests,
1577                elapsed,
1578            } => Some(BenchTuneProgress::Cancelled {
1579                total_tests: *total_tests,
1580                successful_tests: *successful_tests,
1581                failed_tests: *failed_tests,
1582                elapsed: *elapsed,
1583            }),
1584            BenchTuneStatus::Error { error } => Some(BenchTuneProgress::Error {
1585                error: error.clone(),
1586            }),
1587        }
1588    }
1589}
1590
1591impl BenchTuneConfig {
1592    pub fn new(model_path: PathBuf, num_iterations: u32, prompt: String) -> Self {
1593        Self {
1594            model_path,
1595            num_iterations,
1596            prompt,
1597            params_to_test: vec![
1598                BenchTuneParam {
1599                    name: "temperature".to_string(),
1600                    min: 0.4,
1601                    max: 1.0,
1602                    step: 0.1,
1603                    enabled: false,
1604                    variants: vec![],
1605                },
1606                BenchTuneParam {
1607                    name: "top_p".to_string(),
1608                    min: 0.8,
1609                    max: 1.0,
1610                    step: 0.1,
1611                    enabled: false,
1612                    variants: vec![],
1613                },
1614                BenchTuneParam {
1615                    name: "top_k".to_string(),
1616                    min: 10.0,
1617                    max: 40.0,
1618                    step: 5.0,
1619                    enabled: false,
1620                    variants: vec![],
1621                },
1622                BenchTuneParam {
1623                    name: "repeat_penalty".to_string(),
1624                    min: 1.0,
1625                    max: 1.5,
1626                    step: 0.1,
1627                    enabled: false,
1628                    variants: vec![],
1629                },
1630                BenchTuneParam {
1631                    name: "flash_attn".to_string(),
1632                    min: 0.0,
1633                    max: 1.0,
1634                    step: 1.0,
1635                    enabled: false,
1636                    variants: vec![],
1637                },
1638                BenchTuneParam {
1639                    name: "threads".to_string(),
1640                    min: 4.0,
1641                    max: 16.0,
1642                    step: 4.0,
1643                    enabled: false,
1644                    variants: vec![],
1645                },
1646                BenchTuneParam {
1647                    name: "batch_size".to_string(),
1648                    min: 512.0,
1649                    max: 2048.0,
1650                    step: 512.0,
1651                    enabled: false,
1652                    variants: vec![],
1653                },
1654                BenchTuneParam {
1655                    name: "expert_count".to_string(),
1656                    min: -1.0,
1657                    max: 4.0,
1658                    step: 1.0,
1659                    enabled: false,
1660                    variants: vec![],
1661                },
1662                BenchTuneParam {
1663                    name: "spec_type".to_string(),
1664                    min: 0.0,
1665                    max: 8.0,
1666                    step: 1.0,
1667                    enabled: false,
1668                    variants: vec![
1669                        "Off".to_string(),
1670                        "draft-mtp".to_string(),
1671                        "draft-simple".to_string(),
1672                        "draft-eagle3".to_string(),
1673                        "ngram-simple".to_string(),
1674                        "ngram-map-k".to_string(),
1675                        "ngram-map-k4v".to_string(),
1676                        "ngram-mod".to_string(),
1677                        "ngram-cache".to_string(),
1678                    ],
1679                },
1680                BenchTuneParam {
1681                    name: "draft_tokens".to_string(),
1682                    min: 0.0,
1683                    max: 8.0,
1684                    step: 1.0,
1685                    enabled: false,
1686                    variants: vec![],
1687                },
1688            ],
1689            test_duration: Duration::from_secs(30),
1690            bench_mode: BenchTuneMode::default(),
1691            n_predict: 512,
1692            chat_template_kwargs: Some(r#"{"enable_thinking": false}"#.to_string()),
1693            test_timeout: Duration::from_secs(60),
1694        }
1695    }
1696
1697    /// Generate all parameter combinations based on the config
1698    pub fn generate_combinations(&self) -> Vec<BenchTuneParamValue> {
1699        let mut temp_values = vec![None];
1700        let mut top_p_values = vec![None];
1701        let mut top_k_values = vec![None];
1702        let mut repeat_penalty_values = vec![None];
1703        let mut flash_attn_values = vec![None];
1704        let mut threads_values = vec![None];
1705        let mut batch_size_values = vec![None];
1706        let mut expert_count_values = vec![None];
1707        let mut selected_spec_type = None;
1708        if let Some(p) = self.params_to_test.iter().find(|p| p.name == "spec_type") {
1709            let base_idx = (p.min as usize).min(p.variants.len().saturating_sub(1));
1710            let val = &p.variants[base_idx];
1711            if val != "Off" {
1712                selected_spec_type = Some(val.clone());
1713            }
1714        }
1715        let mut spec_type_values = vec![selected_spec_type];
1716        let mut draft_tokens_values = vec![None];
1717
1718        let spec_type_options = vec![
1719            "Off".to_string(),
1720            "draft-mtp".to_string(),
1721            "draft-simple".to_string(),
1722            "draft-eagle3".to_string(),
1723            "ngram-simple".to_string(),
1724            "ngram-map-k".to_string(),
1725            "ngram-map-k4v".to_string(),
1726            "ngram-mod".to_string(),
1727            "ngram-cache".to_string(),
1728        ];
1729
1730        for p in &self.params_to_test {
1731            if !p.enabled {
1732                continue;
1733            }
1734
1735            let vals: Vec<f64> = {
1736                let step_count = ((p.max - p.min) / p.step).ceil() as usize;
1737                (0..=step_count)
1738                    .map(|i| (p.min + (i as f64 * p.step)).min(p.max))
1739                    .collect()
1740            };
1741
1742            match p.name.as_str() {
1743                "temperature" => temp_values = vals.into_iter().map(Some).collect(),
1744                "top_p" => top_p_values = vals.into_iter().map(Some).collect(),
1745                "top_k" => top_k_values = vals.into_iter().map(|v| Some(v as i64)).collect(),
1746                "repeat_penalty" => repeat_penalty_values = vals.into_iter().map(Some).collect(),
1747                "flash_attn" => {
1748                    flash_attn_values = vals.into_iter().map(|v| Some(v >= 0.5)).collect()
1749                }
1750                "threads" => threads_values = vals.into_iter().map(|v| Some(v as u32)).collect(),
1751                "batch_size" => {
1752                    batch_size_values = vals.into_iter().map(|v| Some(v as u32)).collect()
1753                }
1754                "expert_count" => {
1755                    expert_count_values = vals.into_iter().map(|v| Some(v as i32)).collect()
1756                }
1757                "spec_type" => {
1758                    if !p.variants.is_empty() {
1759                        spec_type_values = p.variants.clone().into_iter().map(Some).collect();
1760                    } else {
1761                        let step_count = ((p.max - p.min) / p.step).ceil() as usize;
1762                        spec_type_values = (0..=step_count)
1763                            .map(|i| {
1764                                let idx = i.min(spec_type_options.len() - 1);
1765                                Some(spec_type_options[idx].clone())
1766                            })
1767                            .collect()
1768                    }
1769                }
1770                "draft_tokens" => {
1771                    draft_tokens_values = vals.into_iter().map(|v| Some(v as u32)).collect()
1772                }
1773                _ => {}
1774            }
1775        }
1776
1777        let mut combinations = Vec::new();
1778        for &temp in &temp_values {
1779            for &top_p in &top_p_values {
1780                for &top_k in &top_k_values {
1781                    for &rp in &repeat_penalty_values {
1782                        for &fa in &flash_attn_values {
1783                            for &th in &threads_values {
1784                                for &bs in &batch_size_values {
1785                                    for &ec in &expert_count_values {
1786                                        for st in &spec_type_values {
1787                                            for &dt in &draft_tokens_values {
1788                                                combinations.push(BenchTuneParamValue {
1789                                                    temperature: temp,
1790                                                    top_p,
1791                                                    top_k,
1792                                                    repeat_penalty: rp,
1793                                                    context_length: None,
1794                                                    batch_size: bs,
1795                                                    flash_attn: fa,
1796                                                    threads: th,
1797                                                    expert_count: ec,
1798                                                    spec_type: st.clone(),
1799                                                    draft_tokens: dt,
1800                                                });
1801                                            }
1802                                        }
1803                                    }
1804                                }
1805                            }
1806                        }
1807                    }
1808                }
1809            }
1810        }
1811
1812        combinations
1813    }
1814
1815    /// Get total number of tests to run
1816    pub fn get_total_tests_count(&self) -> usize {
1817        self.generate_combinations().len()
1818    }
1819
1820    /// Get number of parameter combinations (fast, without generating them)
1821    pub fn get_num_combinations(&self) -> usize {
1822        let mut count: u64 = 1;
1823        for p in &self.params_to_test {
1824            if !p.enabled {
1825                continue;
1826            }
1827            let vals = if p.name == "flash_attn" {
1828                2 // On/Off
1829            } else if !p.variants.is_empty() {
1830                p.variants.len()
1831            } else if p.name == "spec_type" {
1832                let step_count = ((p.max - p.min) / p.step).ceil() as usize;
1833                (step_count + 1).min(9)
1834            } else {
1835                let step_count = ((p.max - p.min) / p.step).ceil() as usize;
1836                step_count + 1
1837            };
1838            count *= vals as u64;
1839        }
1840        count as usize
1841    }
1842}
1843
1844// ── Parameter struct field count tests ──────────────────────────
1845
1846#[cfg(test)]
1847mod field_count_tests {
1848    use super::*;
1849
1850    /// Verify ModelSettings has the expected number of fields.
1851    /// If this test fails, a field was added/removed — update the checklist
1852    /// in src/models.rs:665 and all locations listed there.
1853    #[test]
1854    fn test_model_settings_field_count() {
1855        // This test uses reflection-like field access to verify the count.
1856        // If a field is added/removed, the tuple size changes and the
1857        // expected count assertion fails.
1858        let s = ModelSettings::default();
1859        let field_count = count_model_settings_fields(&s);
1860        assert_eq!(
1861            field_count, 75,
1862            "ModelSettings has {} fields (expected 75). \
1863        Update the checklist at src/models.rs:665 and all locations listed there.",
1864            field_count
1865        );
1866    }
1867
1868    /// Count fields in ModelSettings by forcing reference to each one.
1869    /// This is a compile-time guarantee: if a field is removed, the
1870    /// function won't compile. If a field is added, the count changes.
1871    #[allow(clippy::too_many_lines)]
1872    fn count_model_settings_fields(s: &ModelSettings) -> usize {
1873        let _ = (
1874            &s.context_length,
1875            &s.threads,
1876            &s.threads_batch,
1877            &s.batch_size,
1878            &s.ubatch_size,
1879            &s.parallel,
1880            &s.max_concurrent_predictions,
1881            &s.uniform_cache,
1882            &s.kv_cache_offload,
1883            &s.cache_type_k,
1884            &s.cache_type_v,
1885            &s.keep,
1886            &s.swa_full,
1887            &s.mlock,
1888            &s.mmap,
1889            &s.numa,
1890            &s.system_prompt,
1891            &s.system_prompt_preset_name,
1892            &s.gpu_layers_mode,
1893            &s.split_mode,
1894            &s.tensor_split,
1895            &s.main_gpu,
1896            &s.fit,
1897            &s.lora,
1898            &s.lora_scaled,
1899            &s.rpc,
1900            &s.embedding,
1901            &s.flash_attn,
1902            &s.expert_count,
1903            &s.jinja,
1904            &s.chat_template,
1905            &s.chat_template_kwargs,
1906            &s.seed,
1907            &s.temperature,
1908            &s.top_k,
1909            &s.top_p,
1910            &s.min_p,
1911            &s.typical_p,
1912            &s.mirostat,
1913            &s.mirostat_lr,
1914            &s.mirostat_ent,
1915            &s.ignore_eos,
1916            &s.samplers,
1917            &s.repeat_penalty,
1918            &s.repeat_last_n,
1919            &s.presence_penalty,
1920            &s.frequency_penalty,
1921            &s.dry_multiplier,
1922            &s.dry_base,
1923            &s.dry_allowed_length,
1924            &s.dry_penalty_last_n,
1925            &s.rope_scaling,
1926            &s.rope_scale,
1927            &s.rope_freq_base,
1928            &s.rope_freq_scale,
1929            &s.rope_yarn_enabled,
1930            &s.host,
1931            &s.port,
1932            &s.timeout,
1933            &s.cache_prompt,
1934            &s.cache_reuse,
1935            &s.webui,
1936            &s.max_tokens,
1937            &s.cache_type,
1938            &s.backend,
1939            &s.llama_cpp_version_cpu,
1940            &s.llama_cpp_version_vulkan,
1941            &s.llama_cpp_version_rocm,
1942            &s.llama_cpp_version_rocm_lemonade,
1943            &s.llama_cpp_version_cuda,
1944            &s.api_endpoint_enabled,
1945            &s.api_endpoint_port,
1946            &s.spec_type,
1947            &s.draft_tokens,
1948            &s.tags,
1949        );
1950        75
1951    }
1952
1953    /// Verify that is_dirty() uses derived PartialEq (compiler-enforced).
1954    /// This test confirms that two identical settings are not dirty,
1955    /// and two different settings are dirty.
1956    #[test]
1957    fn test_is_dirty_uses_derived_eq() {
1958        let s1 = ModelSettings::default();
1959        let s2 = ModelSettings::default();
1960        let s3 = s1.clone();
1961
1962        // Identical settings should not be dirty
1963        assert!(!s1.is_dirty(&s2));
1964        assert!(!s1.is_dirty(&s3));
1965
1966        // Derived PartialEq must match is_dirty
1967        assert_eq!(s1 != s2, s1.is_dirty(&s2));
1968        assert_eq!(s1 != s3, s1.is_dirty(&s3));
1969    }
1970
1971    /// Verify DefaultParams and ModelSettings share the same field set
1972    /// via the From<DefaultParams> for ModelSettings implementation.
1973    #[test]
1974    fn test_from_default_params_completeness() {
1975        let dp = crate::config::DefaultParams::default();
1976        let ms: ModelSettings = dp.clone().into();
1977
1978        // Verify the From impl produces a valid ModelSettings
1979        // The derived PartialEq ensures all fields were mapped
1980        assert_eq!(ms.context_length, 131072);
1981        assert_eq!(ms.threads, dp.threads);
1982        assert_eq!(ms.temperature, 0.8);
1983        // backend is hardware-dependent in DefaultParams::default(),
1984        // so we just verify it was mapped correctly
1985        assert_eq!(ms.backend, dp.backend);
1986    }
1987}