1use std::path::PathBuf;
2use std::time::Duration;
3
4use serde::{Deserialize, Serialize};
5
/// Lifecycle state of a model entry.
#[derive(Debug, Clone, PartialEq)]
pub enum ModelState {
    /// Not loaded (presumably idle and ready to load — confirm with the UI code).
    Available,
    Loading,
    Benchmarking,
    /// Running; `port`/`pid` presumably identify the server process — confirm.
    Loaded { port: u16, pid: u32 },
    /// The last operation failed; `error` carries the message.
    Failed { error: String },
}
15
/// Sort order for model search results, cycled with [`SearchSort::next`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SearchSort {
    Relevance,
    Downloads,
    Likes,
    Trending,
    CreatedAt,
}

impl SearchSort {
    /// The sort order after `self`; `CreatedAt` wraps around to `Relevance`.
    pub fn next(self) -> Self {
        use SearchSort::*;
        match self {
            Relevance => Downloads,
            Downloads => Likes,
            Likes => Trending,
            Trending => CreatedAt,
            CreatedAt => Relevance,
        }
    }

    /// Short display label for this sort order.
    pub fn label(self) -> &'static str {
        use SearchSort::*;
        match self {
            Relevance => "Relevance",
            Downloads => "Downloads",
            Likes => "Likes",
            Trending => "Trending",
            CreatedAt => "Created",
        }
    }
}
47
/// One model entry returned by a model search.
///
/// NOTE(review): field order is the serialized field order; keep it stable.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct SearchResult {
    pub model_id: String,
    pub model_name: String,
    pub tags: Vec<String>,
    pub downloads: u64,
    pub likes: u64,
    pub pipeline_tag: Option<String>,
    /// Presumably a size in bytes — confirm with the search backend.
    pub size: Option<u64>,
    pub parameters: Option<String>,
    pub capabilities: Vec<String>,
    pub context_length: Option<u32>,
    pub readme: Option<String>,
    pub quantization: Option<String>,
    pub license: Option<String>,
    pub trending_score: i64,
    pub created_at: Option<String>,
    /// Presumably whether the model is already present locally.
    /// Missing from serialized data ⇒ `false` (`#[serde(default)]`).
    #[serde(default)]
    pub downloaded: bool,
}
74
75#[derive(Debug, Clone)]
77pub struct DownloadState {
78 pub model_id: String,
79 pub filename: String,
80 pub total_bytes: u64,
81 pub downloaded_bytes: u64,
82 pub status: DownloadStatus,
83 pub cancelled: bool,
84 pub cancel_token: Option<std::sync::Arc<std::sync::atomic::AtomicBool>>,
85 pub download_state: u8,
87 pub download_state_arc: Option<std::sync::Arc<std::sync::atomic::AtomicU8>>,
89 pub start_time: std::time::Instant,
90 pub bytes_per_second: f64,
91 pub dest: Option<std::path::PathBuf>,
93}
94
95impl DownloadState {
96 pub fn new(model_id: String, filename: String, total_bytes: u64) -> Self {
97 Self {
98 model_id,
99 filename,
100 total_bytes,
101 downloaded_bytes: 0,
102 status: DownloadStatus::Downloading,
103 cancelled: false,
104 cancel_token: None,
105 download_state: 1,
106 download_state_arc: None,
107 start_time: std::time::Instant::now(),
108 bytes_per_second: 0.0,
109 dest: None,
110 }
111 }
112}
113
114impl ModelSettings {
115 pub fn get_active_backend_version(&self) -> Option<&String> {
117 match self.backend {
118 Backend::Cpu => self.llama_cpp_version_cpu.as_ref(),
119 Backend::Vulkan => self.llama_cpp_version_vulkan.as_ref(),
120 Backend::Rocm => self.llama_cpp_version_rocm.as_ref(),
121 Backend::RocmLemonade => self.llama_cpp_version_rocm_lemonade.as_ref(),
122 Backend::Cuda => self.llama_cpp_version_cuda.as_ref(),
123 _ => None,
124 }
125 }
126
127 pub fn get_active_backend_version_display(&self) -> &str {
129 self.get_active_backend_version()
130 .map(|s| s.as_str())
131 .unwrap_or("latest")
132 }
133
134 pub fn set_active_backend_version(&mut self, tag: Option<String>) {
136 match self.backend {
137 Backend::Cpu => self.llama_cpp_version_cpu = tag,
138 Backend::Vulkan => self.llama_cpp_version_vulkan = tag,
139 Backend::Rocm => self.llama_cpp_version_rocm = tag,
140 Backend::RocmLemonade => self.llama_cpp_version_rocm_lemonade = tag,
141 Backend::Cuda => self.llama_cpp_version_cuda = tag,
142 _ => {}
143 }
144 }
145}
146
/// Removes a trailing `.gguf` extension from `name`, ignoring ASCII case.
///
/// `.gguf`, `.GGUF` and mixed forms such as `.Gguf` are all stripped. A name
/// without the extension is returned unchanged. Non-ASCII names are handled
/// safely: the cut point must fall on a UTF-8 character boundary.
pub fn strip_gguf(name: &str) -> &str {
    const EXT: &str = ".gguf";
    name.len()
        .checked_sub(EXT.len())
        // `get` returns None if `cut` would split a multi-byte character.
        .filter(|&cut| name.get(cut..).is_some_and(|tail| tail.eq_ignore_ascii_case(EXT)))
        .map_or(name, |cut| &name[..cut])
}
153
/// Normalises a user-supplied host string.
///
/// Only the first whitespace-separated token is kept. Input that is empty or
/// all whitespace becomes `"127.0.0.1"`. A token containing `:` that does not
/// already start with `[` is wrapped in brackets (IPv6 literal form).
pub fn clean_host(host: &str) -> String {
    let Some(token) = host.split_whitespace().next() else {
        return "127.0.0.1".to_string();
    };
    let needs_brackets = token.contains(':') && !token.starts_with('[');
    if needs_brackets {
        format!("[{token}]")
    } else {
        token.to_string()
    }
}
170
/// Display label for a host.
///
/// An empty host or `127.0.0.1` is shown as `localhost (127.0.0.1)`; anything
/// else is returned as-is.
pub fn format_host(host: &str) -> &str {
    let is_default_loopback = host.is_empty() || host == "127.0.0.1";
    if is_default_loopback {
        "localhost (127.0.0.1)"
    } else {
        host
    }
}
178
/// Builds per-model settings from the configured defaults.
///
/// This is a field-for-field move of `dp`. The one translation is GPU offload:
/// a negative numeric `gpu_layers` forces `GpuLayersMode::All`; otherwise
/// `dp.gpu_layers_mode` is used unchanged.
impl From<crate::config::DefaultParams> for ModelSettings {
    fn from(dp: crate::config::DefaultParams) -> Self {
        Self {
            context_length: dp.context_length,
            threads: dp.threads,
            threads_batch: dp.threads_batch,
            batch_size: dp.batch_size,
            ubatch_size: dp.ubatch_size,
            parallel: dp.parallel,
            max_concurrent_predictions: dp.max_concurrent_predictions,
            uniform_cache: dp.uniform_cache,
            kv_cache_offload: dp.kv_cache_offload,
            cache_type_k: dp.cache_type_k,
            cache_type_v: dp.cache_type_v,
            keep: dp.keep,
            swa_full: dp.swa_full,
            mlock: dp.mlock,
            mmap: dp.mmap,
            numa: dp.numa,
            system_prompt: dp.system_prompt,
            system_prompt_preset_name: dp.system_prompt_preset_name,

            // A negative `gpu_layers` (presumably a legacy "offload everything"
            // value — confirm in config) overrides the mode with `All`.
            gpu_layers_mode: match dp.gpu_layers {
                n if n < 0 => GpuLayersMode::All,
                _ => dp.gpu_layers_mode,
            },
            split_mode: dp.split_mode,
            tensor_split: dp.tensor_split,
            main_gpu: dp.main_gpu,
            fit: dp.fit,
            lora: dp.lora,
            lora_scaled: dp.lora_scaled,
            rpc: dp.rpc,
            embedding: dp.embedding,
            flash_attn: dp.flash_attn,
            expert_count: dp.expert_count,
            jinja: dp.jinja,
            chat_template: dp.chat_template,
            chat_template_kwargs: dp.chat_template_kwargs,
            seed: dp.seed,
            temperature: dp.temperature,
            top_k: dp.top_k,
            top_p: dp.top_p,
            min_p: dp.min_p,
            typical_p: dp.typical_p,
            mirostat: dp.mirostat,
            mirostat_lr: dp.mirostat_lr,
            mirostat_ent: dp.mirostat_ent,
            ignore_eos: dp.ignore_eos,
            samplers: dp.samplers,
            repeat_penalty: dp.repeat_penalty,
            repeat_last_n: dp.repeat_last_n,
            presence_penalty: dp.presence_penalty,
            frequency_penalty: dp.frequency_penalty,
            dry_multiplier: dp.dry_multiplier,
            dry_base: dp.dry_base,
            dry_allowed_length: dp.dry_allowed_length,
            dry_penalty_last_n: dp.dry_penalty_last_n,
            rope_scaling: dp.rope_scaling,
            rope_scale: dp.rope_scale,
            rope_freq_base: dp.rope_freq_base,
            rope_freq_scale: dp.rope_freq_scale,
            rope_yarn_enabled: dp.rope_yarn_enabled,
            host: dp.host,
            port: dp.port,
            timeout: dp.timeout,
            cache_prompt: dp.cache_prompt,
            cache_reuse: dp.cache_reuse,
            webui: dp.webui,
            max_tokens: dp.max_tokens,
            cache_type: dp.cache_type,
            backend: dp.backend,
            llama_cpp_version_cpu: dp.llama_cpp_version_cpu,
            llama_cpp_version_vulkan: dp.llama_cpp_version_vulkan,
            llama_cpp_version_rocm: dp.llama_cpp_version_rocm,
            llama_cpp_version_rocm_lemonade: dp.llama_cpp_version_rocm_lemonade,
            llama_cpp_version_cuda: dp.llama_cpp_version_cuda,
            api_endpoint_enabled: dp.api_endpoint_enabled,
            api_endpoint_port: dp.api_endpoint_port,
            ws_server_enabled: dp.ws_server_enabled,
            ws_server_port: dp.ws_server_port,
            ws_server_auth_key: dp.ws_server_auth_key,
            ws_server_tls_enabled: dp.ws_server_tls_enabled,
            ws_server_tls_cert: dp.ws_server_tls_cert,
            ws_server_tls_key: dp.ws_server_tls_key,
            spec_type: dp.spec_type,
            draft_tokens: dp.draft_tokens,
            tags: dp.tags,
        }
    }
}
270
/// How many model layers are offloaded to the GPU.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize, Hash)]
#[derive(Default)]
pub enum GpuLayersMode {
    /// Automatic choice; `estimate_vram_mib` assumes 60% of the layers (20 if unknown).
    #[default]
    Auto,
    /// Offload exactly this many layers (`estimate_vram_mib` clamps to the layer count).
    Specific(u32),
    /// Offload every layer.
    All,
}


/// Status of a file download.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DownloadStatus {
    Downloading,
    Paused,
    Complete,
    /// Failed with the contained message.
    Error(String),
    Cancelled,
}
290
291#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Hash, Default)]
295pub enum CacheType {
296 #[serde(rename = "f16")]
297 #[default]
298 F16,
299 #[serde(rename = "bf16")]
300 BF16,
301 #[serde(rename = "fq8_0")]
302 Fq8_0,
303 #[serde(rename = "fq4_1")]
304 Fq4_1,
305}
306
307impl std::fmt::Display for CacheType {
308 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
309 match self {
310 CacheType::F16 => write!(f, "f16"),
311 CacheType::BF16 => write!(f, "bf16"),
312 CacheType::Fq8_0 => write!(f, "fq8_0"),
313 CacheType::Fq4_1 => write!(f, "fq4_1"),
314 }
315 }
316}
317
/// Element type for the K or V cache.
///
/// Serialized under the lowercase names given by the `rename` attributes,
/// which also match the `Display` output.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default, Hash)]
pub enum CacheQuantType {
    #[serde(rename = "f32")]
    F32,
    #[serde(rename = "f16")]
    #[default]
    F16,
    #[serde(rename = "bf16")]
    BF16,
    #[serde(rename = "q8_0")]
    Q8_0,
    #[serde(rename = "q4_0")]
    Q4_0,
    #[serde(rename = "q4_1")]
    Q4_1,
    #[serde(rename = "iq4_nl")]
    Iq4Nl,
    #[serde(rename = "q5_0")]
    Q5_0,
    #[serde(rename = "q5_1")]
    Q5_1,
}

/// Element type of the K cache.
pub type CacheTypeK = CacheQuantType;
/// Element type of the V cache.
pub type CacheTypeV = CacheQuantType;
344
345impl CacheQuantType {
346 pub fn from_u8(n: u8) -> Self {
347 match n {
348 0 => Self::F32,
349 1 => Self::F16,
350 2 => Self::BF16,
351 3 => Self::Q8_0,
352 4 => Self::Q5_1,
353 5 => Self::Q5_0,
354 6 => Self::Q4_1,
355 7 => Self::Q4_0,
356 8 => Self::Iq4Nl,
357 _ => Self::F16,
358 }
359 }
360 pub fn next(&self) -> Self {
361 match self {
362 Self::F32 => Self::F16,
363 Self::F16 => Self::BF16,
364 Self::BF16 => Self::Q8_0,
365 Self::Q8_0 => Self::Q5_1,
366 Self::Q5_1 => Self::Q5_0,
367 Self::Q5_0 => Self::Q4_1,
368 Self::Q4_1 => Self::Q4_0,
369 Self::Q4_0 => Self::Iq4Nl,
370 Self::Iq4Nl => Self::F32,
371 }
372 }
373 pub fn prev(&self) -> Self {
374 match self {
375 Self::F32 => Self::Iq4Nl,
376 Self::F16 => Self::F32,
377 Self::BF16 => Self::F16,
378 Self::Q8_0 => Self::BF16,
379 Self::Q5_1 => Self::Q8_0,
380 Self::Q5_0 => Self::Q5_1,
381 Self::Q4_1 => Self::Q5_0,
382 Self::Q4_0 => Self::Q4_1,
383 Self::Iq4Nl => Self::Q4_0,
384 }
385 }
386}
387
388impl std::fmt::Display for CacheQuantType {
389 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
390 match self {
391 Self::F32 => write!(f, "f32"),
392 Self::F16 => write!(f, "f16"),
393 Self::BF16 => write!(f, "bf16"),
394 Self::Q8_0 => write!(f, "q8_0"),
395 Self::Q4_0 => write!(f, "q4_0"),
396 Self::Q4_1 => write!(f, "q4_1"),
397 Self::Iq4Nl => write!(f, "iq4_nl"),
398 Self::Q5_0 => write!(f, "q5_0"),
399 Self::Q5_1 => write!(f, "q5_1"),
400 }
401 }
402}
403
404impl From<&str> for CacheQuantType {
405 fn from(s: &str) -> Self {
406 match s {
407 "F32" => Self::F32,
408 "F16" => Self::F16,
409 "BF16" => Self::BF16,
410 "Q8_0" => Self::Q8_0,
411 "Q4_0" => Self::Q4_0,
412 "Q4_1" => Self::Q4_1,
413 "Iq4Nl" => Self::Iq4Nl,
414 "Q5_0" => Self::Q5_0,
415 "Q5_1" => Self::Q5_1,
416 _ => Self::F16, }
418 }
419}
420
421#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Hash, Default)]
423pub enum SplitMode {
424 #[serde(rename = "none")]
425 None,
426 #[serde(rename = "layer")]
427 #[default]
428 Layer,
429 #[serde(rename = "row")]
430 Row,
431 #[serde(rename = "tensor")]
432 Tensor,
433}
434
435impl std::fmt::Display for SplitMode {
436 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
437 match self {
438 SplitMode::None => write!(f, "none"),
439 SplitMode::Layer => write!(f, "layer"),
440 SplitMode::Row => write!(f, "row"),
441 SplitMode::Tensor => write!(f, "tensor"),
442 }
443 }
444}
445
446#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Hash, Default)]
448pub enum NumMode {
449 #[serde(rename = "none")]
450 #[default]
451 None,
452 #[serde(rename = "distribute")]
453 Distribute,
454 #[serde(rename = "isolate")]
455 Isolate,
456 #[serde(rename = "numactl")]
457 Numactl,
458}
459
460impl std::fmt::Display for NumMode {
461 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
462 match self {
463 NumMode::None => write!(f, "none"),
464 NumMode::Distribute => write!(f, "distribute"),
465 NumMode::Isolate => write!(f, "isolate"),
466 NumMode::Numactl => write!(f, "numactl"),
467 }
468 }
469}
470
471#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Hash, Default)]
473pub enum RopeScaling {
474 #[serde(rename = "none")]
475 #[default]
476 None,
477 #[serde(rename = "linear")]
478 Linear,
479 #[serde(rename = "yarn")]
480 Yarn,
481}
482
483impl std::fmt::Display for RopeScaling {
484 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
485 match self {
486 RopeScaling::None => write!(f, "none"),
487 RopeScaling::Linear => write!(f, "linear"),
488 RopeScaling::Yarn => write!(f, "yarn"),
489 }
490 }
491}
492
493#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Hash, Default)]
495pub enum Mirostat {
496 #[serde(rename = "0")]
497 #[default]
498 Off,
499 #[serde(rename = "1")]
500 V1,
501 #[serde(rename = "2")]
502 Mirostat2,
503}
504
505impl std::fmt::Display for Mirostat {
506 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
507 match self {
508 Mirostat::Off => write!(f, "off"),
509 Mirostat::V1 => write!(f, "1"),
510 Mirostat::Mirostat2 => write!(f, "2"),
511 }
512 }
513}
514
515#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
518pub struct Samplers(pub String);
519
520impl Default for Samplers {
521 fn default() -> Self {
522 Self("penalties;dry;top_n_sigma;top_k;typ_p;top_p;min_p;xtc;temperature".to_string())
523 }
524}
525
526impl std::fmt::Display for Samplers {
527 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
528 write!(f, "{}", self.0)
529 }
530}
531
/// llama.cpp build flavour to run.
///
/// Serde names use underscores (`win_cuda_12_4`). `Backend::slug` / `Display`
/// use dashes and dotted CUDA versions (`win-cuda-12.4`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum Backend {
    #[serde(rename = "cpu")]
    #[default]
    Cpu,
    #[serde(rename = "vulkan")]
    Vulkan,
    #[serde(rename = "rocm")]
    Rocm,
    #[serde(rename = "rocm_lemonade")]
    RocmLemonade,
    #[serde(rename = "cuda")]
    Cuda,
    #[serde(rename = "cpu_arm64")]
    CpuArm64,
    #[serde(rename = "win_cpu")]
    CpuWindows,
    #[serde(rename = "win_vulkan")]
    VulkanWindows,
    #[serde(rename = "win_cuda_12_4")]
    CudaWindows12_4,
    #[serde(rename = "win_cuda_13_1")]
    CudaWindows13_1,
    #[serde(rename = "win_hip")]
    HipWindows,
    #[serde(rename = "macos_arm64")]
    CpuMacosArm64,
    #[serde(rename = "macos_x64")]
    CpuMacosX64,
}
563
564impl Backend {
565 pub fn slug(&self) -> &'static str {
567 match self {
568 Backend::Cpu => "cpu",
569 Backend::Vulkan => "vulkan",
570 Backend::Rocm => "rocm",
571 Backend::RocmLemonade => "rocm-lemonade",
572 Backend::Cuda => "cuda",
573 Backend::CpuArm64 => "cpu-arm64",
574 Backend::CpuWindows => "win-cpu",
575 Backend::VulkanWindows => "win-vulkan",
576 Backend::CudaWindows12_4 => "win-cuda-12.4",
577 Backend::CudaWindows13_1 => "win-cuda-13.1",
578 Backend::HipWindows => "win-hip",
579 Backend::CpuMacosArm64 => "macos-arm64",
580 Backend::CpuMacosX64 => "macos-x64",
581 }
582 }
583
584 pub fn is_linux(self) -> bool {
586 matches!(
587 self,
588 Backend::Cpu
589 | Backend::Vulkan
590 | Backend::Rocm
591 | Backend::RocmLemonade
592 | Backend::Cuda
593 | Backend::CpuArm64
594 )
595 }
596
597 pub fn is_windows(self) -> bool {
599 matches!(
600 self,
601 Backend::CpuWindows
602 | Backend::VulkanWindows
603 | Backend::CudaWindows12_4
604 | Backend::CudaWindows13_1
605 | Backend::HipWindows
606 )
607 }
608
609 pub fn is_macos(self) -> bool {
611 matches!(self, Backend::CpuMacosArm64 | Backend::CpuMacosX64)
612 }
613
614 pub fn from_str(s: &str) -> Self {
616 let s = s.to_lowercase();
617 if s.starts_with("vulkan") || s.starts_with("vk") {
618 Backend::Vulkan
619 } else if s.starts_with("rocm") || s.starts_with("ro") {
620 if s.contains("lemonade") {
621 Backend::RocmLemonade
622 } else {
623 Backend::Rocm
624 }
625 } else if s.starts_with("cuda") || s.starts_with("cu") {
626 Backend::Cuda
627 } else {
628 Backend::Cpu }
630 }
631}
632
633impl std::fmt::Display for Backend {
634 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
635 write!(f, "{}", self.slug())
636 }
637}
638
639#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
641pub enum ServerMode {
642 #[serde(rename = "normal")]
643 #[default]
644 Normal,
645 #[serde(rename = "router")]
646 Router,
647 #[serde(rename = "bench_gpu", alias = "bench")]
648 Bench,
649 #[serde(rename = "bench_tune")]
650 BenchTune,
651}
652
653impl std::fmt::Display for ServerMode {
654 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
655 match self {
656 ServerMode::Normal => write!(f, "Normal"),
657 ServerMode::Router => write!(f, "Router (XP!)"),
658 ServerMode::Bench => write!(f, "Bench GPU"),
659 ServerMode::BenchTune => write!(f, "BenchTune"),
660 }
661 }
662}
663
/// Complete per-model configuration: load and runtime parameters, sampling,
/// server and endpoint settings.
///
/// NOTE(review): field order is the serialized field order; keep it stable.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ModelSettings {
    /// Context size; scales the KV-cache term in `estimate_vram_mib`.
    pub context_length: u32,
    pub threads: u32,
    pub threads_batch: u32,
    /// Scales the activation-buffer term in `estimate_vram_mib`.
    pub batch_size: u32,
    pub ubatch_size: u32,
    /// `estimate_vram_mib` divides the KV estimate by this when `uniform_cache` is set.
    pub parallel: u32,
    pub max_concurrent_predictions: Option<u32>,
    pub uniform_cache: bool,
    pub kv_cache_offload: bool,
    /// K-cache element type; `estimate_vram_mib` treats `None` as F16.
    pub cache_type_k: Option<CacheTypeK>,
    /// V-cache element type; `estimate_vram_mib` treats `None` as F16.
    pub cache_type_v: Option<CacheTypeV>,
    pub keep: i32,
    pub swa_full: bool,
    pub mlock: bool,
    pub mmap: bool,
    pub numa: NumMode,
    pub system_prompt: String,
    pub system_prompt_preset_name: String,

    // GPU offload and model loading.
    pub gpu_layers_mode: GpuLayersMode,
    pub split_mode: SplitMode,
    pub tensor_split: String,
    pub main_gpu: i32,
    pub fit: bool,
    pub lora: Option<PathBuf>,
    pub lora_scaled: Option<(PathBuf, f32)>,
    pub rpc: String,
    pub embedding: bool,
    /// `estimate_vram_mib` halves its KV-cache estimate when this is set.
    pub flash_attn: bool,
    pub expert_count: i32,
    pub jinja: bool,
    pub chat_template: Option<String>,
    pub chat_template_kwargs: Option<String>,

    // Sampling.
    pub seed: i32,
    pub temperature: f32,
    pub top_k: i32,
    pub top_p: f32,
    pub min_p: f32,
    pub typical_p: f32,
    pub mirostat: Mirostat,
    pub mirostat_lr: f32,
    pub mirostat_ent: f32,
    pub ignore_eos: bool,
    pub samplers: Samplers,

    // Repetition / DRY penalties.
    pub repeat_penalty: f32,
    pub repeat_last_n: i32,
    pub presence_penalty: Option<f32>,
    pub frequency_penalty: Option<f32>,
    pub dry_multiplier: f32,
    pub dry_base: f32,
    pub dry_allowed_length: i32,
    pub dry_penalty_last_n: i32,

    // RoPE.
    pub rope_scaling: RopeScaling,
    pub rope_scale: f32,
    pub rope_freq_base: f32,
    pub rope_freq_scale: f32,
    pub rope_yarn_enabled: bool,

    // Server.
    pub host: String,
    pub port: u16,
    pub timeout: u32,
    pub cache_prompt: bool,
    pub cache_reuse: u32,
    pub webui: bool,

    pub max_tokens: Option<u32>,
    pub cache_type: CacheType,
    /// Selects which `llama_cpp_version_*` field is active.
    pub backend: Backend,
    /// Per-backend llama.cpp version tags; `None` is displayed as `"latest"`.
    pub llama_cpp_version_cpu: Option<String>,
    pub llama_cpp_version_vulkan: Option<String>,
    pub llama_cpp_version_rocm: Option<String>,
    pub llama_cpp_version_rocm_lemonade: Option<String>,
    pub llama_cpp_version_cuda: Option<String>,
    pub api_endpoint_enabled: bool,
    pub api_endpoint_port: u16,
    pub spec_type: String,
    pub draft_tokens: u32,
    pub tags: Vec<String>,
    /// WebSocket server settings; `from_config` copies these from `config.ws_server`.
    pub ws_server_enabled: bool,
    pub ws_server_port: u16,
    pub ws_server_auth_key: Option<String>,
    pub ws_server_tls_enabled: bool,
    pub ws_server_tls_cert: Option<String>,
    pub ws_server_tls_key: Option<String>,
}
858
859impl Default for ModelSettings {
860 fn default() -> Self {
861 let mut s: Self = crate::config::DefaultParams::default().into();
862 s.uniform_cache = false;
864 s.cache_type_k = Some(CacheTypeK::F16);
865 s.cache_type_v = Some(CacheTypeV::F16);
866 s.cache_type = CacheType::default();
867 s.backend = Backend::Cpu;
868 s.presence_penalty = Some(0.0);
869 s.frequency_penalty = Some(0.0);
870 s
871 }
872}
873
874impl ModelSettings {
875 pub fn from_config(config: &crate::config::Config) -> Self {
877 let mut settings: ModelSettings = config.default.clone().into();
878 settings.ws_server_enabled = config.ws_server.enabled;
879 settings.ws_server_port = config.ws_server.port;
880 settings.ws_server_auth_key = config.ws_server.auth_key.clone();
881 settings.ws_server_tls_enabled = config.ws_server.tls_enabled;
882 settings.ws_server_tls_cert = config.ws_server.tls_cert.clone();
883 settings.ws_server_tls_key = config.ws_server.tls_key.clone();
884 settings
885 }
886}
887
/// Fixed prompt for benchmark runs.
///
/// Presumably sent to the server to measure generation speed — confirm with the benchmark code.
pub const BENCHMARK_PROMPT: &str = "Create Mona Lisa image in ascii art using text, number, symbol, everything possible. this should be the perfect painting.";
890
/// A model file found on disk.
#[derive(Debug, Clone)]
pub struct DiscoveredModel {
    pub path: PathBuf,
    pub name: String,
    /// File size, presumably in bytes — confirm with the scanner.
    pub file_size: u64,
    pub display_name: String,
}
899
/// Metadata read from a GGUF model file by `GgufMetadata::from_path`.
///
/// Numeric fields are 0 and strings empty when the key is absent.
#[derive(Debug, Clone, Default)]
pub struct GgufMetadata {
    /// `<arch>.block_count`.
    pub layers: u32,
    /// `<arch>.embedding_length`.
    pub hidden_size: u32,
    /// `<arch>.context_length`.
    pub n_ctx_train: u32,
    /// `<arch>.attention.head_count`.
    pub n_head: u32,
    /// `<arch>.attention.head_count_kv`.
    pub n_kv_head: u32,
    /// `general.architecture`.
    pub arch: String,
    /// Human-readable name of `general.file_type`.
    pub file_type: String,
    /// NOTE(review): not populated by `from_path`; `file_type` carries the quantization name.
    pub quantization: String,
    pub model_parameters: String,
    /// `general.domain`.
    pub domain: String,
    /// `general.capabilities`, plus `"chat"` when a chat template is present.
    pub capabilities: Vec<String>,
    /// `tokenizer.ggml.model`.
    pub tokenizer: String,
    /// Number of entries in `tokenizer.ggml.tokens`.
    pub vocab_size: u32,
    /// `mtp.draft_tokens`, read only when `arch == "mtp"`.
    pub draft_tokens: u32,
}
918
919impl GgufMetadata {
920 pub fn from_path(path: &std::path::Path) -> anyhow::Result<Self> {
921 let path_str = path.to_string_lossy();
922 let mut container = gguf_rs::get_gguf_container(&path_str)
923 .map_err(|e| anyhow::anyhow!("Failed to get GGUF container: {}", e))?;
924 let model_data = container
925 .decode()
926 .map_err(|e| anyhow::anyhow!("Failed to decode GGUF: {}", e))?;
927
928 let mut meta = Self::default();
929
930 let extract_str = |key: &str| -> String {
931 model_data
932 .metadata()
933 .get(key)
934 .and_then(|v| v.as_str().map(|s| s.to_string()))
935 .unwrap_or_default()
936 };
937
938 let extract_num = |key: &str| -> Option<u64> {
939 model_data.metadata().get(key).and_then(|v| {
940 v.as_u64()
941 .or_else(|| v.as_i64().map(|x| x as u64))
942 .or_else(|| v.as_f64().map(|x| x as u64))
943 })
944 };
945
946 meta.arch = extract_str("general.architecture");
947 let prefix = if meta.arch.is_empty() {
948 "llama"
949 } else {
950 &meta.arch
951 };
952
953 let get_num_with_fallback = |suffix: &str| -> u32 {
954 extract_num(&format!("{}.{}", prefix, suffix))
955 .or_else(|| {
956 if prefix != "llama" {
957 extract_num(&format!("llama.{}", suffix))
958 } else {
959 None
960 }
961 })
962 .unwrap_or(0) as u32
963 };
964
965 meta.layers = get_num_with_fallback("block_count");
966 meta.hidden_size = get_num_with_fallback("embedding_length");
967 meta.n_ctx_train = get_num_with_fallback("context_length");
968 meta.n_head = get_num_with_fallback("attention.head_count");
969 meta.n_kv_head = get_num_with_fallback("attention.head_count_kv");
970
971 if let Some(value) = model_data.metadata().get("tokenizer.ggml.tokens")
972 && let Some(arr) = value.as_array()
973 {
974 meta.vocab_size = arr.len() as u32;
975 }
976
977 if meta.arch == "mtp" {
978 meta.draft_tokens = extract_num("mtp.draft_tokens").unwrap_or(0) as u32;
979 }
980
981 if let Some(v) = extract_num("general.file_type") {
982 meta.file_type = match v {
983 0 => "F32".to_string(),
984 1 => "F16".to_string(),
985 2 => "Q4_0".to_string(),
986 3 => "Q4_1".to_string(),
987 7 => "Q8_0".to_string(),
988 8 => "Q5_0".to_string(),
989 9 => "Q5_1".to_string(),
990 10 => "Q2_K".to_string(),
991 11 => "Q3_K_S".to_string(),
992 12 => "Q3_K_M".to_string(),
993 13 => "Q3_K_L".to_string(),
994 14 => "Q4_K_S".to_string(),
995 15 => "Q4_K_M".to_string(),
996 16 => "Q5_K_S".to_string(),
997 17 => "Q5_K_M".to_string(),
998 18 => "Q6_K".to_string(),
999 19 => "IQ2_XXS".to_string(),
1000 20 => "IQ2_XS".to_string(),
1001 21 => "IQ3_XXS".to_string(),
1002 22 => "IQ1_S".to_string(),
1003 23 => "IQ4_NL".to_string(),
1004 24 => "IQ3_S".to_string(),
1005 25 => "IQ2_S".to_string(),
1006 26 => "IQ4_XS".to_string(),
1007 _ => format!("Unknown ({})", v),
1008 };
1009 }
1010
1011 if let Some(value) = model_data.metadata().get("general.capabilities")
1012 && let Some(arr) = value.as_array()
1013 {
1014 for v in arr {
1015 if let Some(s) = v.as_str() {
1016 meta.capabilities.push(s.to_string());
1017 }
1018 }
1019 }
1020
1021 if model_data
1022 .metadata()
1023 .contains_key("tokenizer.chat_template")
1024 {
1025 meta.capabilities.push("chat".to_string());
1026 }
1027
1028 meta.tokenizer = extract_str("tokenizer.ggml.model");
1029 meta.domain = extract_str("general.domain");
1030 meta.model_parameters = model_data.model_parameters();
1031
1032 Ok(meta)
1033 }
1034}
1035
/// Runtime metrics collected for the model server.
///
/// `Default` is all-zero with `loaded == false`.
#[derive(Debug, Clone, Default)]
pub struct ServerMetrics {
    pub loaded: bool,
    pub tps: f64,
    pub prompt_tps: f64,
    pub cpu_usage: f64,
    pub gpu_mem_used: u64,
    pub gpu_mem_total: u64,
    pub ram_used: u64,
    pub ctx_used: u32,
    pub ctx_max: u32,
    pub total_vram_used: u64,
    pub decoded_tokens: u64,
    pub gen_tps: f64,
    pub latency_per_token_ms: f64,
    pub prompt_latency_ms: f64,
}

/// One device buffer allocation, with its size in MiB.
#[derive(Debug, Clone)]
pub struct GPUBuffer {
    pub device: String,
    pub buffer_size_mib: f64,
}

/// Model-loading progress.
///
/// Totals are `None` until known.
#[derive(Debug, Clone, Default)]
pub struct LoadProgress {
    pub layers_total: Option<u32>,
    pub layers_loaded: Option<u32>,
    pub tensors_total: Option<u32>,
    pub tensors_loaded: u32,
    pub buffers: Vec<GPUBuffer>,
}
1102
/// Flat, serializable snapshot of server metrics plus the active settings.
///
/// Built by `WsMetrics::from_metrics`; presumably pushed to WebSocket clients — confirm.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct WsMetrics {
    pub model_name: String,
    pub loaded: bool,
    pub state: String,
    pub tps: f64,
    pub prompt_tps: f64,
    pub ctx_used: u32,
    pub ctx_max: u32,
    pub cpu_usage: f64,
    pub gpu_mem_used: u64,
    pub gpu_mem_total: u64,
    pub ram_used: u64,
    pub latency_per_token_ms: f64,
    pub decoded_tokens: u64,
    pub gen_tps: f64,
    /// Seconds since the Unix epoch when the snapshot was built (0 if the clock is before it).
    pub timestamp: u64,
    pub cmd_display: Option<String>,
    // Settings echoed from `ModelSettings`.
    pub threads: u32,
    pub threads_batch: u32,
    pub context_length: u32,
    pub ubatch_size: u32,
    pub batch_size: u32,
    pub temperature: f32,
    /// Converted from the `i32` setting.
    pub top_k: u32,
    pub top_p: f32,
    pub min_p: f32,
    pub typical_p: f32,
    pub seed: i32,
    pub repeat_penalty: f32,
    pub repeat_last_n: i32,
    pub presence_penalty: Option<f32>,
    pub frequency_penalty: Option<f32>,
    /// 0 (off), 1 or 2.
    pub mirostat: Option<u32>,
    pub mirostat_lr: Option<f32>,
    pub mirostat_ent: Option<f32>,
    pub max_tokens: Option<u32>,
    pub flash_attn: bool,
    pub kv_cache_offload: bool,
    /// Lowercase cache type name, e.g. `"q8_0"`.
    pub cache_type_k: Option<String>,
    pub cache_type_v: Option<String>,
    pub uniform_cache: bool,
    pub mlock: bool,
    pub mmap: bool,
    pub embedding: bool,
    pub jinja: bool,
    pub ignore_eos: bool,
    pub samplers: String,
    /// Converted from the `i32` setting.
    pub expert_count: u32,
    /// `"Auto"`, `"All"`, or the specific layer count.
    pub gpu_layers: String,
    /// Backend slug, e.g. `"win-cuda-12.4"`.
    pub backend: String,
    /// Active backend's version tag or `"latest"`.
    pub llama_cpp_version: String,
    pub spec_type: String,
    pub draft_tokens: u32,
}
1161
1162impl WsMetrics {
1163 pub fn from_metrics(
1164 metrics: &ServerMetrics,
1165 model_name: &str,
1166 state: &str,
1167 settings: &crate::models::ModelSettings,
1168 cmd_display: Option<&str>,
1169 ) -> Self {
1170 use std::time::{SystemTime, UNIX_EPOCH};
1171 let timestamp = SystemTime::now()
1172 .duration_since(UNIX_EPOCH)
1173 .map(|d| d.as_secs())
1174 .unwrap_or(0);
1175 let gpu_layers = match settings.gpu_layers_mode {
1176 crate::models::GpuLayersMode::Auto => "Auto".to_string(),
1177 crate::models::GpuLayersMode::Specific(n) => n.to_string(),
1178 crate::models::GpuLayersMode::All => "All".to_string(),
1179 };
1180 Self {
1181 model_name: model_name.to_string(),
1182 loaded: metrics.loaded,
1183 state: state.to_string(),
1184 tps: metrics.tps,
1185 prompt_tps: metrics.prompt_tps,
1186 ctx_used: metrics.ctx_used,
1187 ctx_max: metrics.ctx_max,
1188 cpu_usage: metrics.cpu_usage,
1189 gpu_mem_used: metrics.gpu_mem_used,
1190 gpu_mem_total: metrics.gpu_mem_total,
1191 ram_used: metrics.ram_used,
1192 latency_per_token_ms: metrics.latency_per_token_ms,
1193 decoded_tokens: metrics.decoded_tokens,
1194 gen_tps: metrics.gen_tps,
1195 timestamp,
1196 cmd_display: cmd_display.map(String::from),
1197 threads: settings.threads,
1198 threads_batch: settings.threads_batch,
1199 context_length: settings.context_length,
1200 ubatch_size: settings.ubatch_size,
1201 batch_size: settings.batch_size,
1202 temperature: settings.temperature,
1203 top_k: settings.top_k as u32,
1204 top_p: settings.top_p,
1205 min_p: settings.min_p,
1206 typical_p: settings.typical_p,
1207 seed: settings.seed,
1208 repeat_penalty: settings.repeat_penalty,
1209 repeat_last_n: settings.repeat_last_n,
1210 presence_penalty: settings.presence_penalty,
1211 frequency_penalty: settings.frequency_penalty,
1212 mirostat: Some(match settings.mirostat {
1213 crate::models::Mirostat::Off => 0,
1214 crate::models::Mirostat::V1 => 1,
1215 crate::models::Mirostat::Mirostat2 => 2,
1216 }),
1217 mirostat_lr: Some(settings.mirostat_lr),
1218 mirostat_ent: Some(settings.mirostat_ent),
1219 max_tokens: settings.max_tokens,
1220 flash_attn: settings.flash_attn,
1221 kv_cache_offload: settings.kv_cache_offload,
1222 cache_type_k: settings.cache_type_k.map(|k| k.to_string()),
1223 cache_type_v: settings.cache_type_v.map(|k| k.to_string()),
1224 uniform_cache: settings.uniform_cache,
1225 mlock: settings.mlock,
1226 mmap: settings.mmap,
1227 embedding: settings.embedding,
1228 jinja: settings.jinja,
1229 ignore_eos: settings.ignore_eos,
1230 samplers: settings.samplers.to_string(),
1231 expert_count: settings.expert_count as u32,
1232 gpu_layers,
1233 backend: settings.backend.to_string(),
1234 llama_cpp_version: settings.get_active_backend_version_display().to_string(),
1235 spec_type: settings.spec_type.clone(),
1236 draft_tokens: settings.draft_tokens,
1237 }
1238 }
1239}
1240
1241pub fn estimate_vram_mib(
1253 model_mib: u64,
1254 settings: &ModelSettings,
1255 total_layers: u32,
1256 hidden_size_opt: Option<u32>,
1257 n_head_opt: Option<u32>,
1258 n_kv_head_opt: Option<u32>,
1259 gpu_mem_total_mib: u64,
1260) -> u64 {
1261 let model_mib_f = model_mib as f64;
1262
1263 let gpu_layers = match settings.gpu_layers_mode {
1265 GpuLayersMode::Auto => {
1266 if total_layers > 0 {
1268 (total_layers as f64 * 0.6) as u32
1269 } else {
1270 20
1271 }
1272 }
1273 GpuLayersMode::Specific(n) => {
1274 if total_layers > 0 {
1275 n.min(total_layers)
1276 } else {
1277 n
1278 }
1279 }
1280 GpuLayersMode::All => {
1281 if total_layers > 0 {
1282 total_layers
1283 } else {
1284 32
1285 }
1286 }
1287 };
1288
1289 let model_vram = if total_layers > 0 && gpu_layers > 0 {
1291 model_mib_f * (gpu_layers as f64 / total_layers as f64).min(1.0)
1292 } else if gpu_layers > 0 {
1293 model_mib_f
1294 } else {
1295 0.0
1296 };
1297
1298 if matches!(settings.gpu_layers_mode, GpuLayersMode::Specific(0)) {
1299 return 0; }
1301
1302 let hidden_size = match hidden_size_opt {
1305 Some(h) => h as f64,
1306 None => {
1307 let params_est = model_mib_f / 550.0;
1308 (1024.0 * params_est.sqrt().max(1.0) * 1.5).max(512.0)
1309 }
1310 };
1311
1312 let gqa_ratio = match (n_head_opt, n_kv_head_opt) {
1318 (Some(n_head), Some(n_kv_head)) if n_head > 0 => n_kv_head as f64 / n_head as f64,
1319 _ => 1.0, };
1321
1322 let flash_attn_factor = if settings.flash_attn { 0.5 } else { 1.0 };
1325
1326 let uniform_cache_factor = if settings.uniform_cache {
1328 1.0 / settings.parallel as f64
1329 } else {
1330 1.0
1331 };
1332
1333 let kv_mib = (2.0
1340 * hidden_size
1341 * settings.context_length as f64
1342 * total_layers as f64
1343 * gqa_ratio
1344 * gpu_layers as f64
1345 / total_layers as f64 * flash_attn_factor
1347 * uniform_cache_factor
1348 * kv_quant_bytes(
1349 settings.cache_type_k.unwrap_or(CacheTypeK::F16),
1350 settings.cache_type_v.unwrap_or(CacheTypeV::F16)
1351 ))
1352 / (1024.0 * 1024.0);
1353
1354 let activation_mib = (settings.batch_size as f64 * hidden_size * 8.0) / (1024.0 * 1024.0);
1357
1358 let fixed_overhead = if gpu_mem_total_mib > 0 {
1361 gpu_mem_total_mib as f64 * 0.038
1362 } else {
1363 500.0
1364 };
1365
1366 let total_mib = model_vram + kv_mib + activation_mib + fixed_overhead + 550.0;
1367
1368 total_mib.ceil() as u64
1369}
1370
1371fn kv_quant_bytes(k_type: CacheQuantType, v_type: CacheQuantType) -> f64 {
1376 let get_bytes = |t: CacheQuantType| match t {
1377 CacheQuantType::F32 => 4.0,
1378 CacheQuantType::F16 | CacheQuantType::BF16 => 2.0,
1379 CacheQuantType::Q8_0 => 1.0,
1380 CacheQuantType::Q5_0 | CacheQuantType::Q5_1 => 0.625, CacheQuantType::Q4_0 | CacheQuantType::Q4_1 | CacheQuantType::Iq4Nl => 0.5, };
1383 (get_bytes(k_type) + get_bytes(v_type)) / 2.0
1384}
1385
1386impl ModelSettings {
1387 pub fn is_dirty(&self, other: &Self) -> bool {
1390 self != other
1391 }
1392}
1393
/// Configuration of a benchmark/tuning sweep over runtime and sampling parameters.
///
/// Build one with [`BenchTuneConfig::new`] and expand it into concrete parameter
/// sets with [`BenchTuneConfig::generate_combinations`]. `Eq` is sound because
/// [`BenchTuneParam`] compares its floats bit-wise.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct BenchTuneConfig {
    /// Path of the model to benchmark.
    pub model_path: PathBuf,
    /// NOTE(review): presumably the number of runs per parameter combination
    /// (results keep `per_iteration_metrics`) — confirm against the runner.
    pub num_iterations: u32,
    /// Prompt text used for the benchmark.
    pub prompt: String,
    /// Sweep definitions; only entries with `enabled == true` are expanded.
    pub params_to_test: Vec<BenchTuneParam>,
    /// 30 seconds by default (`new`). NOTE(review): whether this is a per-test
    /// timeout or a run length is not visible here — confirm.
    pub test_duration: Duration,
    /// Sweep mode; `BenchTuneMode::Full` by default.
    pub bench_mode: BenchTuneMode,
    /// 512 by default (`new`); presumably the number of tokens to generate.
    pub n_predict: u32,
    /// Optional JSON object string, presumably forwarded as chat-template
    /// kwargs; `new` sets it to `{"enable_thinking": false}`.
    pub chat_template_kwargs: Option<String>,
}
1406
/// One sweepable parameter, expanded to the values `min, min + step, …` up to `max`.
///
/// `PartialEq`/`Eq` are implemented by hand, comparing the floats bit-wise.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchTuneParam {
    /// Parameter key. `generate_combinations` recognizes `temperature`, `top_p`,
    /// `top_k`, `repeat_penalty`, `flash_attn`, `threads`, `batch_size` and
    /// `expert_count`; any other name is ignored.
    pub name: String,
    /// First value of the sweep.
    pub min: f64,
    /// Upper bound of the sweep; generated values are clamped to it.
    pub max: f64,
    /// Increment between consecutive values.
    pub step: f64,
    /// Disabled parameters are not swept and contribute `None` to every combination.
    pub enabled: bool,
}
1415
1416impl PartialEq for BenchTuneParam {
1417 fn eq(&self, other: &Self) -> bool {
1418 self.name == other.name
1419 && self.min.to_bits() == other.min.to_bits()
1420 && self.max.to_bits() == other.max.to_bits()
1421 && self.step.to_bits() == other.step.to_bits()
1422 && self.enabled == other.enabled
1423 }
1424}
1425impl Eq for BenchTuneParam {}
1426
/// One concrete parameter combination to benchmark.
///
/// `generate_combinations` leaves a field `None` when that parameter is not
/// being swept, and never sets `context_length`, `spec_type` or `draft_tokens`.
/// `PartialEq`/`Eq` are implemented by hand, comparing the floats bit-wise.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchTuneParamValue {
    pub temperature: Option<f64>,
    pub top_p: Option<f64>,
    pub top_k: Option<i64>,
    pub repeat_penalty: Option<f64>,
    pub context_length: Option<u32>,
    pub batch_size: Option<u32>,
    pub flash_attn: Option<bool>,
    pub threads: Option<u32>,
    pub expert_count: Option<i32>,
    pub spec_type: Option<String>,
    pub draft_tokens: Option<u32>,
}
1441
1442impl PartialEq for BenchTuneParamValue {
1443 fn eq(&self, other: &Self) -> bool {
1444 self.temperature.map(|v| v.to_bits()) == other.temperature.map(|v| v.to_bits())
1445 && self.top_p.map(|v| v.to_bits()) == other.top_p.map(|v| v.to_bits())
1446 && self.top_k == other.top_k
1447 && self.repeat_penalty.map(|v| v.to_bits()) == other.repeat_penalty.map(|v| v.to_bits())
1448 && self.context_length == other.context_length
1449 && self.batch_size == other.batch_size
1450 && self.flash_attn == other.flash_attn
1451 && self.threads == other.threads
1452 && self.expert_count == other.expert_count
1453 && self.spec_type == other.spec_type
1454 && self.draft_tokens == other.draft_tokens
1455 }
1456}
1457impl Eq for BenchTuneParamValue {}
1458
/// Outcome of benchmarking one parameter combination.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchTuneResult {
    /// The combination that was tested.
    pub params: BenchTuneParamValue,
    /// NOTE(review): presumably aggregated over `per_iteration_metrics` — confirm.
    pub metrics: BenchTuneMetrics,
    /// Generated outputs (presumably one per iteration).
    pub outputs: Vec<String>,
    /// Metrics recorded for each individual iteration.
    pub per_iteration_metrics: Vec<BenchTuneMetrics>,
    /// Presumably the settings the combination was applied on top of, when recorded.
    pub base_settings: Option<ModelSettings>,
}
1467
/// Throughput and latency measurements for a benchmark run.
///
/// NOTE(review): units are not visible here — the `*_tps` fields are presumably
/// tokens per second; confirm the units of `latency_per_token` and
/// `first_token_time` at the measuring site.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchTuneMetrics {
    pub prompt_tps: f64,
    pub generation_tps: f64,
    pub combined_tps: f64,
    pub latency_per_token: f64,
    pub first_token_time: f64,
}
1476
/// State of a benchmark sweep; serializable counterpart of [`BenchTuneProgress`]
/// (converted with `BenchTuneProgress::from_status`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum BenchTuneStatus {
    /// A test is running: `current` out of `total`, with the combination under test.
    /// NOTE(review): confirm whether `progress` is a 0–1 fraction or a percentage.
    Running {
        current: usize,
        total: usize,
        progress: f32,
        current_params: BenchTuneParamValue,
    },
    /// The sweep finished; unlike `PartiallyCompleted` it carries no failure count.
    Completed {
        total_tests: usize,
        successful_tests: usize,
        elapsed: Duration,
    },
    /// The sweep finished, presumably with some failed tests (`failed_tests`).
    PartiallyCompleted {
        total_tests: usize,
        successful_tests: usize,
        failed_tests: usize,
        elapsed: Duration,
    },
    /// The sweep was cancelled before finishing.
    Cancelled {
        total_tests: usize,
        successful_tests: usize,
        failed_tests: usize,
        elapsed: Duration,
    },
    /// The sweep failed with the given error message.
    Error {
        error: String,
    },
}
1506
/// Which kind of benchmark sweep to run.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[derive(Default)]
pub enum BenchTuneMode {
    /// NOTE(review): presumably varies only runtime/sampling parameters without
    /// restarting the model — confirm against the runner.
    RuntimeOnly,
    /// Default mode (`#[default]`).
    #[default]
    Full,
}
1516
1517
/// In-process progress event for a benchmark sweep.
///
/// Mirrors [`BenchTuneStatus`] variant for variant (see `from_status`), but is
/// not serializable: it derives only `Debug` and `Clone`.
#[derive(Debug, Clone)]
pub enum BenchTuneProgress {
    /// A test is running: `current` out of `total`, with the combination under test.
    Running {
        current: usize,
        total: usize,
        progress: f32,
        current_params: BenchTuneParamValue,
    },
    /// The sweep finished; carries no failure count.
    Completed {
        total_tests: usize,
        successful_tests: usize,
        elapsed: Duration,
    },
    /// The sweep finished with a recorded number of failed tests.
    PartiallyCompleted {
        total_tests: usize,
        successful_tests: usize,
        failed_tests: usize,
        elapsed: Duration,
    },
    /// The sweep was cancelled before finishing.
    Cancelled {
        total_tests: usize,
        successful_tests: usize,
        failed_tests: usize,
        elapsed: Duration,
    },
    /// The sweep failed with the given error message.
    Error { error: String },
}
1551
1552impl BenchTuneProgress {
1553 pub fn from_status(status: &BenchTuneStatus) -> Option<Self> {
1554 match status {
1555 BenchTuneStatus::Running {
1556 current,
1557 total,
1558 progress,
1559 current_params,
1560 } => Some(BenchTuneProgress::Running {
1561 current: *current,
1562 total: *total,
1563 progress: *progress,
1564 current_params: current_params.clone(),
1565 }),
1566 BenchTuneStatus::Completed {
1567 total_tests,
1568 successful_tests,
1569 elapsed,
1570 } => Some(BenchTuneProgress::Completed {
1571 total_tests: *total_tests,
1572 successful_tests: *successful_tests,
1573 elapsed: *elapsed,
1574 }),
1575 BenchTuneStatus::PartiallyCompleted {
1576 total_tests,
1577 successful_tests,
1578 failed_tests,
1579 elapsed,
1580 } => Some(BenchTuneProgress::PartiallyCompleted {
1581 total_tests: *total_tests,
1582 successful_tests: *successful_tests,
1583 failed_tests: *failed_tests,
1584 elapsed: *elapsed,
1585 }),
1586 BenchTuneStatus::Cancelled {
1587 total_tests,
1588 successful_tests,
1589 failed_tests,
1590 elapsed,
1591 } => Some(BenchTuneProgress::Cancelled {
1592 total_tests: *total_tests,
1593 successful_tests: *successful_tests,
1594 failed_tests: *failed_tests,
1595 elapsed: *elapsed,
1596 }),
1597 BenchTuneStatus::Error { error } => Some(BenchTuneProgress::Error {
1598 error: error.clone(),
1599 }),
1600 }
1601 }
1602}
1603
1604impl BenchTuneConfig {
1605 pub fn new(model_path: PathBuf, num_iterations: u32, prompt: String) -> Self {
1606 Self {
1607 model_path,
1608 num_iterations,
1609 prompt,
1610 params_to_test: vec![
1611 BenchTuneParam {
1612 name: "temperature".to_string(),
1613 min: 0.4,
1614 max: 1.0,
1615 step: 0.1,
1616 enabled: false,
1617 },
1618 BenchTuneParam {
1619 name: "top_p".to_string(),
1620 min: 0.8,
1621 max: 1.0,
1622 step: 0.1,
1623 enabled: false,
1624 },
1625 BenchTuneParam {
1626 name: "top_k".to_string(),
1627 min: 10.0,
1628 max: 40.0,
1629 step: 5.0,
1630 enabled: false,
1631 },
1632 BenchTuneParam {
1633 name: "repeat_penalty".to_string(),
1634 min: 1.0,
1635 max: 1.2,
1636 step: 0.1,
1637 enabled: false,
1638 },
1639 BenchTuneParam {
1640 name: "flash_attn".to_string(),
1641 min: 0.0,
1642 max: 1.0,
1643 step: 1.0,
1644 enabled: false,
1645 },
1646 BenchTuneParam {
1647 name: "threads".to_string(),
1648 min: 4.0,
1649 max: 16.0,
1650 step: 4.0,
1651 enabled: false,
1652 },
1653 BenchTuneParam {
1654 name: "batch_size".to_string(),
1655 min: 512.0,
1656 max: 2048.0,
1657 step: 512.0,
1658 enabled: false,
1659 },
1660 BenchTuneParam {
1661 name: "expert_count".to_string(),
1662 min: 1.0,
1663 max: 4.0,
1664 step: 1.0,
1665 enabled: false,
1666 },
1667 ],
1668 test_duration: Duration::from_secs(30),
1669 bench_mode: BenchTuneMode::default(),
1670 n_predict: 512,
1671 chat_template_kwargs: Some(r#"{"enable_thinking": false}"#.to_string()),
1672 }
1673 }
1674
1675 pub fn generate_combinations(&self) -> Vec<BenchTuneParamValue> {
1677 let mut temp_values = vec![None];
1678 let mut top_p_values = vec![None];
1679 let mut top_k_values = vec![None];
1680 let mut repeat_penalty_values = vec![None];
1681 let mut flash_attn_values = vec![None];
1682 let mut threads_values = vec![None];
1683 let mut batch_size_values = vec![None];
1684 let mut expert_count_values = vec![None];
1685
1686 for p in &self.params_to_test {
1687 if !p.enabled {
1688 continue;
1689 }
1690
1691 let vals: Vec<f64> = {
1692 let step_count = ((p.max - p.min) / p.step).ceil() as usize;
1693 (0..=step_count)
1694 .map(|i| (p.min + (i as f64 * p.step)).min(p.max))
1695 .collect()
1696 };
1697
1698 match p.name.as_str() {
1699 "temperature" => temp_values = vals.into_iter().map(Some).collect(),
1700 "top_p" => top_p_values = vals.into_iter().map(Some).collect(),
1701 "top_k" => top_k_values = vals.into_iter().map(|v| Some(v as i64)).collect(),
1702 "repeat_penalty" => repeat_penalty_values = vals.into_iter().map(Some).collect(),
1703 "flash_attn" => {
1704 flash_attn_values = vals.into_iter().map(|v| Some(v >= 0.5)).collect()
1705 }
1706 "threads" => threads_values = vals.into_iter().map(|v| Some(v as u32)).collect(),
1707 "batch_size" => {
1708 batch_size_values = vals.into_iter().map(|v| Some(v as u32)).collect()
1709 }
1710 "expert_count" => {
1711 expert_count_values = vals.into_iter().map(|v| Some(v as i32)).collect()
1712 }
1713 _ => {}
1714 }
1715 }
1716
1717 let mut combinations = Vec::new();
1718 for &temp in &temp_values {
1719 for &top_p in &top_p_values {
1720 for &top_k in &top_k_values {
1721 for &rp in &repeat_penalty_values {
1722 for &fa in &flash_attn_values {
1723 for &th in &threads_values {
1724 for &bs in &batch_size_values {
1725 for &ec in &expert_count_values {
1726 combinations.push(BenchTuneParamValue {
1727 temperature: temp,
1728 top_p,
1729 top_k,
1730 repeat_penalty: rp,
1731 context_length: None,
1732 batch_size: bs,
1733 flash_attn: fa,
1734 threads: th,
1735 expert_count: ec,
1736 spec_type: None,
1737 draft_tokens: None,
1738 });
1739 }
1740 }
1741 }
1742 }
1743 }
1744 }
1745 }
1746 }
1747
1748 combinations
1749 }
1750
1751 pub fn get_total_tests_count(&self) -> usize {
1753 self.generate_combinations().len()
1754 }
1755}
1756
/// Guards against `ModelSettings` gaining or losing fields without the
/// dependent checklist being updated.
#[cfg(test)]
mod field_count_tests {
    use super::*;

    #[test]
    fn test_model_settings_field_count() {
        let s = ModelSettings::default();
        let field_count = count_model_settings_fields(&s);
        assert_eq!(
            field_count, 81,
            "ModelSettings has {} fields (expected 81). \
             Update the checklist at src/models.rs:665 and all locations listed there.",
            field_count
        );
    }

    /// Returns the number of `ModelSettings` fields.
    ///
    /// The real guard is at compile time: the struct pattern below names every
    /// field and has no `..` rest pattern, so adding, removing or renaming a
    /// field makes this function fail to compile until the list is updated. The
    /// returned number must be kept equal to the number of names listed.
    #[allow(clippy::too_many_lines)] // one line per field, by design
    fn count_model_settings_fields(s: &ModelSettings) -> usize {
        let ModelSettings {
            context_length: _,
            threads: _,
            threads_batch: _,
            batch_size: _,
            ubatch_size: _,
            parallel: _,
            max_concurrent_predictions: _,
            uniform_cache: _,
            kv_cache_offload: _,
            cache_type_k: _,
            cache_type_v: _,
            keep: _,
            swa_full: _,
            mlock: _,
            mmap: _,
            numa: _,
            system_prompt: _,
            system_prompt_preset_name: _,
            gpu_layers_mode: _,
            split_mode: _,
            tensor_split: _,
            main_gpu: _,
            fit: _,
            lora: _,
            lora_scaled: _,
            rpc: _,
            embedding: _,
            flash_attn: _,
            expert_count: _,
            jinja: _,
            chat_template: _,
            chat_template_kwargs: _,
            seed: _,
            temperature: _,
            top_k: _,
            top_p: _,
            min_p: _,
            typical_p: _,
            mirostat: _,
            mirostat_lr: _,
            mirostat_ent: _,
            ignore_eos: _,
            samplers: _,
            repeat_penalty: _,
            repeat_last_n: _,
            presence_penalty: _,
            frequency_penalty: _,
            dry_multiplier: _,
            dry_base: _,
            dry_allowed_length: _,
            dry_penalty_last_n: _,
            rope_scaling: _,
            rope_scale: _,
            rope_freq_base: _,
            rope_freq_scale: _,
            rope_yarn_enabled: _,
            host: _,
            port: _,
            timeout: _,
            cache_prompt: _,
            cache_reuse: _,
            webui: _,
            max_tokens: _,
            cache_type: _,
            backend: _,
            llama_cpp_version_cpu: _,
            llama_cpp_version_vulkan: _,
            llama_cpp_version_rocm: _,
            llama_cpp_version_rocm_lemonade: _,
            llama_cpp_version_cuda: _,
            api_endpoint_enabled: _,
            api_endpoint_port: _,
            spec_type: _,
            draft_tokens: _,
            tags: _,
            ws_server_enabled: _,
            ws_server_port: _,
            ws_server_auth_key: _,
            ws_server_tls_enabled: _,
            ws_server_tls_cert: _,
            ws_server_tls_key: _,
        } = s;
        81
    }

    #[test]
    fn test_is_dirty_uses_derived_eq() {
        let s1 = ModelSettings::default();
        let s2 = ModelSettings::default();
        let s3 = s1.clone();

        // Identical settings are never dirty.
        assert!(!s1.is_dirty(&s2));
        assert!(!s1.is_dirty(&s3));

        // `is_dirty` must agree exactly with `!=`.
        assert_eq!(s1 != s2, s1.is_dirty(&s2));
        assert_eq!(s1 != s3, s1.is_dirty(&s3));

        // A single changed field must be detected.
        let mut s4 = s1.clone();
        s4.flash_attn = !s4.flash_attn;
        assert!(s1.is_dirty(&s4));
        assert_eq!(s1 != s4, s1.is_dirty(&s4));
    }

    #[test]
    fn test_from_default_params_completeness() {
        let dp = crate::config::DefaultParams::default();
        let ms: ModelSettings = dp.clone().into();

        assert_eq!(ms.context_length, 131072);
        assert_eq!(ms.threads, dp.threads);
        assert_eq!(ms.temperature, 0.8);
        assert_eq!(ms.backend, dp.backend);
    }
}