Skip to main content

llama_rs/
config.rs

1//! TOML configuration file support for llama-rs.
2//!
3//! Provides a [`Config`] struct that maps all CLI arguments to a TOML configuration
4//! file, enabling persistent and shareable inference setups.
5//!
6//! # Configuration Precedence (highest to lowest)
7//!
8//! 1. CLI arguments (always win)
9//! 2. Environment variables
10//! 3. TOML config file
11//! 4. Default values
12//!
13//! # Example TOML
14//!
15//! ```toml
16//! # llama-rs.toml
17//!
18//! [model]
19//! path = "/path/to/model.gguf"
20//! gpu = true
21//!
22//! [generation]
23//! temperature = 0.7
24//! top_k = 40
25//! top_p = 0.95
26//! repeat_penalty = 1.1
27//! max_tokens = 512
28//! seed = 42
29//!
30//! [chat]
31//! system_prompt = "You are a helpful AI assistant."
32//! max_tokens = 1024
33//!
34//! [server]
35//! host = "0.0.0.0"
36//! port = 8080
37//!
38//! [quantize]
39//! output_type = "q4_k"
40//! threads = 8
41//!
42//! [bench]
43//! n_prompt = 512
44//! n_gen = 128
45//! repetitions = 3
46//! threads = 4
47//!
48//! [embed]
49//! format = "json"
50//! ```
51
52use serde::{Deserialize, Serialize};
53use std::path::Path;
54
55use crate::engine::EngineConfig;
56
57// ============================================================================
58// Error type
59// ============================================================================
60
/// Errors that can occur while loading, parsing, or saving configuration.
///
/// The `Io`, `Toml`, and `TomlSerialize` variants are marked `#[from]`, so the
/// corresponding library errors convert automatically when propagated with `?`.
#[derive(thiserror::Error, Debug)]
pub enum ConfigError {
    /// An underlying `std::io::Error` (e.g. while reading or writing a config file).
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    /// The input could not be deserialized from TOML.
    #[error("TOML parse error: {0}")]
    Toml(#[from] toml::de::Error),

    /// The configuration could not be serialized to TOML.
    #[error("TOML serialize error: {0}")]
    TomlSerialize(#[from] toml::ser::Error),

    /// Any other configuration problem, described by a free-form message
    /// (used by [`Config::load`] when an explicitly requested file does not exist).
    #[error("Config error: {0}")]
    Other(String),
}
76
77// ============================================================================
78// Top-level configuration
79// ============================================================================
80
/// Top-level TOML configuration covering all CLI arguments.
///
/// Each field maps to a TOML table of the same name (`[model]`, `[generation]`, ...)
/// and corresponds to a subcommand or functional area.
///
/// Because of `#[serde(default)]`, any table missing from the file is filled in
/// from that section's `Default` implementation, so an empty file is valid.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(default)]
pub struct Config {
    /// Model path and hardware settings (`[model]`).
    pub model: ModelSection,

    /// Generation/sampling parameters (`[generation]`; shared by `run`, `chat`, `serve`).
    pub generation: GenerationSection,

    /// Chat-specific overrides (`[chat]`); see [`Config::to_chat_engine_config`].
    pub chat: ChatSection,

    /// HTTP server settings (`[server]`).
    pub server: ServerSection,

    /// Quantization settings (`[quantize]`).
    pub quantize: QuantizeSection,

    /// Benchmark settings (`[bench]`).
    pub bench: BenchSection,

    /// Embedding extraction settings (`[embed]`).
    pub embed: EmbedSection,
}
109
110// ============================================================================
111// Section structs
112// ============================================================================
113
114/// Model path and hardware configuration.
115#[derive(Debug, Clone, Serialize, Deserialize)]
116#[serde(default)]
117#[derive(Default)]
118pub struct ModelSection {
119    /// Path to the GGUF model file.
120    pub path: Option<String>,
121
122    /// Use GPU acceleration (CUDA/Metal/Vulkan).
123    pub gpu: bool,
124
125    /// KV cache type: "f32", "turboquant2", "turboquant3", "turboquant2-qjl", "turboquant3-qjl".
126    pub kv_cache_type: String,
127}
128
/// Generation and sampling parameters (the `[generation]` table).
///
/// These are shared across `run`, `chat`, and `serve` commands.
/// Command-specific sections (like `[chat]`) can override individual values.
/// Keys missing from the file take the values from this type's `Default` impl.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct GenerationSection {
    /// Temperature for sampling (0.0 = greedy, higher = more random). Default: 0.7.
    pub temperature: f32,

    /// Top-K sampling: only consider the K most likely tokens (0 = disabled). Default: 40.
    pub top_k: usize,

    /// Top-P (nucleus) sampling: cumulative probability cutoff. Default: 0.95.
    pub top_p: f32,

    /// Repetition penalty (1.0 = no penalty). Default: 1.1.
    pub repeat_penalty: f32,

    /// Default maximum tokens to generate. Default: 512.
    pub max_tokens: usize,

    /// Random seed for reproducible generation. Default: `None`.
    pub seed: Option<u64>,

    /// Override the model's native max context length (0 = use model default).
    /// Reduces KV cache memory for large-context models on constrained hardware.
    /// Default: 0.
    pub max_context_len: usize,
}
158
159impl Default for GenerationSection {
160    fn default() -> Self {
161        Self {
162            temperature: 0.7,
163            top_k: 40,
164            top_p: 0.95,
165            repeat_penalty: 1.1,
166            max_tokens: 512,
167            seed: None,
168            max_context_len: 0,
169        }
170    }
171}
172
/// Chat-specific configuration overrides (the `[chat]` table).
///
/// Values here override the corresponding `[generation]` values
/// when running the `chat` command. Every field is optional and defaults to
/// `None`, meaning "no override"; see [`Config::to_chat_engine_config`] for
/// how the overrides are applied.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(default)]
pub struct ChatSection {
    /// System prompt for chat sessions.
    pub system_prompt: Option<String>,

    /// Override max_tokens for chat (defaults to `generation.max_tokens`).
    pub max_tokens: Option<usize>,

    /// Override temperature for chat (defaults to `generation.temperature`).
    pub temperature: Option<f32>,

    /// Override top_p for chat (defaults to `generation.top_p`).
    pub top_p: Option<f32>,

    /// Override top_k for chat (defaults to `generation.top_k`).
    pub top_k: Option<usize>,

    /// Override repeat_penalty for chat (defaults to `generation.repeat_penalty`).
    pub repeat_penalty: Option<f32>,
}
198
/// HTTP server configuration (the `[server]` table).
///
/// Keys missing from the file take the values from this type's `Default` impl
/// (`127.0.0.1:8080`, no RAG settings).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct ServerSection {
    /// Host address to bind to. Default: "127.0.0.1".
    pub host: String,

    /// Port to listen on. Default: 8080.
    pub port: u16,

    /// RAG database URL (PostgreSQL with pgvector). Default: `None`.
    pub rag_database_url: Option<String>,

    /// Path to RAG configuration file. Default: `None`.
    pub rag_config: Option<String>,
}
215
216impl Default for ServerSection {
217    fn default() -> Self {
218        Self {
219            host: "127.0.0.1".to_string(),
220            port: 8080,
221            rag_database_url: None,
222            rag_config: None,
223        }
224    }
225}
226
227impl ServerSection {
228    /// Build a full URL from host and port (e.g. `http://192.168.1.4:8080`).
229    /// Returns `None` if host is localhost/default (i.e. no remote server configured).
230    pub fn host_url(&self) -> Option<String> {
231        // Only return a URL if the host is NOT localhost — a remote was explicitly configured
232        if self.host == "127.0.0.1" || self.host == "localhost" || self.host == "0.0.0.0" {
233            None
234        } else {
235            Some(format!("http://{}:{}", self.host, self.port))
236        }
237    }
238}
239
/// Quantization settings (the `[quantize]` table).
///
/// Keys missing from the file take the values from this type's `Default` impl.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct QuantizeSection {
    /// Target quantization type (q4_0, q4_1, q5_0, q5_1, q8_0, q2_k, q3_k, q4_k, q5_k, q6_k).
    /// Default: "q4_0".
    pub output_type: String,

    /// Number of threads to use. Default: `None`.
    pub threads: Option<usize>,
}
250
251impl Default for QuantizeSection {
252    fn default() -> Self {
253        Self {
254            output_type: "q4_0".to_string(),
255            threads: None,
256        }
257    }
258}
259
/// Benchmark settings (the `[bench]` table).
///
/// Keys missing from the file take the values from this type's `Default` impl.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct BenchSection {
    /// Number of prompt tokens to process. Default: 512.
    pub n_prompt: usize,

    /// Number of tokens to generate. Default: 128.
    pub n_gen: usize,

    /// Number of repetitions for averaging. Default: 3.
    pub repetitions: usize,

    /// Number of threads to use. Default: `None`.
    pub threads: Option<usize>,
}
276
277impl Default for BenchSection {
278    fn default() -> Self {
279        Self {
280            n_prompt: 512,
281            n_gen: 128,
282            repetitions: 3,
283            threads: None,
284        }
285    }
286}
287
/// Embedding extraction settings (the `[embed]` table).
///
/// A missing `format` key takes the value from this type's `Default` impl.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct EmbedSection {
    /// Output format: "json" or "raw". Default: "json".
    pub format: String,
}
295
296impl Default for EmbedSection {
297    fn default() -> Self {
298        Self {
299            format: "json".to_string(),
300        }
301    }
302}
303
304// ============================================================================
305// Config implementation
306// ============================================================================
307
/// Default config file names to search for (in order).
///
/// [`Config::load`] probes these when no explicit path is given and loads the
/// first one that exists. The paths are relative, so they are resolved
/// against the process's current working directory.
pub const DEFAULT_CONFIG_PATHS: &[&str] = &[
    "llama-rs.toml",
    "config/llama-rs.toml",
    ".llama-rs.toml",
];
314
315impl Config {
316    /// Load configuration from a TOML file.
317    pub fn from_file(path: impl AsRef<Path>) -> Result<Self, ConfigError> {
318        let content = std::fs::read_to_string(path.as_ref())?;
319        let config: Self = toml::from_str(&content)?;
320        Ok(config)
321    }
322
323    /// Load configuration from environment variables.
324    ///
325    /// Supported variables:
326    /// - `LLAMA_MODEL_PATH` - Path to GGUF model file
327    /// - `LLAMA_GPU` - Enable GPU acceleration ("1", "true", "yes")
328    /// - `LLAMA_TEMPERATURE` - Sampling temperature
329    /// - `LLAMA_TOP_K` - Top-K sampling value
330    /// - `LLAMA_TOP_P` - Top-P sampling value
331    /// - `LLAMA_REPEAT_PENALTY` - Repetition penalty
332    /// - `LLAMA_MAX_TOKENS` - Maximum tokens to generate
333    /// - `LLAMA_SEED` - Random seed
334    /// - `LLAMA_HOST` - Server host address
335    /// - `LLAMA_PORT` - Server port
336    /// - `LLAMA_SYSTEM_PROMPT` - Default system prompt for chat
337    pub fn from_env() -> Self {
338        let mut config = Self::default();
339
340        if let Ok(path) = std::env::var("LLAMA_MODEL_PATH") {
341            config.model.path = Some(path);
342        }
343        if let Ok(gpu) = std::env::var("LLAMA_GPU") {
344            config.model.gpu = matches!(gpu.to_lowercase().as_str(), "1" | "true" | "yes");
345        }
346        if let Ok(val) = std::env::var("LLAMA_TEMPERATURE")
347            && let Ok(v) = val.parse()
348        {
349            config.generation.temperature = v;
350        }
351        if let Ok(val) = std::env::var("LLAMA_TOP_K")
352            && let Ok(v) = val.parse()
353        {
354            config.generation.top_k = v;
355        }
356        if let Ok(val) = std::env::var("LLAMA_TOP_P")
357            && let Ok(v) = val.parse()
358        {
359            config.generation.top_p = v;
360        }
361        if let Ok(val) = std::env::var("LLAMA_REPEAT_PENALTY")
362            && let Ok(v) = val.parse()
363        {
364            config.generation.repeat_penalty = v;
365        }
366        if let Ok(val) = std::env::var("LLAMA_MAX_TOKENS")
367            && let Ok(v) = val.parse()
368        {
369            config.generation.max_tokens = v;
370        }
371        if let Ok(val) = std::env::var("LLAMA_SEED")
372            && let Ok(v) = val.parse()
373        {
374            config.generation.seed = Some(v);
375        }
376        if let Ok(val) = std::env::var("LLAMA_HOST") {
377            config.server.host = val;
378        }
379        if let Ok(val) = std::env::var("LLAMA_PORT")
380            && let Ok(v) = val.parse()
381        {
382            config.server.port = v;
383        }
384        if let Ok(val) = std::env::var("LLAMA_SYSTEM_PROMPT") {
385            config.chat.system_prompt = Some(val);
386        }
387
388        config
389    }
390
391    /// Load configuration with full precedence chain.
392    ///
393    /// 1. If `config_path` is provided, load from that file.
394    /// 2. Otherwise, search default locations (`llama-rs.toml`, etc.).
395    /// 3. Override with environment variables.
396    pub fn load(config_path: Option<impl AsRef<Path>>) -> Result<Self, ConfigError> {
397        let mut config = Self::default();
398
399        // Try to load from explicit path
400        if let Some(path) = config_path {
401            let p = path.as_ref();
402            if p.exists() {
403                config = Self::from_file(p)?;
404            } else {
405                return Err(ConfigError::Other(format!(
406                    "Config file not found: {}",
407                    p.display()
408                )));
409            }
410        } else {
411            // Search default locations
412            for path in DEFAULT_CONFIG_PATHS {
413                if Path::new(path).exists() {
414                    config = Self::from_file(path)?;
415                    break;
416                }
417            }
418        }
419
420        // Apply environment variable overrides
421        config.apply_env();
422
423        Ok(config)
424    }
425
426    /// Apply environment variable overrides to the current configuration.
427    pub fn apply_env(&mut self) {
428        if let Ok(path) = std::env::var("LLAMA_MODEL_PATH") {
429            self.model.path = Some(path);
430        }
431        if let Ok(gpu) = std::env::var("LLAMA_GPU") {
432            self.model.gpu = matches!(gpu.to_lowercase().as_str(), "1" | "true" | "yes");
433        }
434        if let Ok(val) = std::env::var("LLAMA_TEMPERATURE")
435            && let Ok(v) = val.parse()
436        {
437            self.generation.temperature = v;
438        }
439        if let Ok(val) = std::env::var("LLAMA_TOP_K")
440            && let Ok(v) = val.parse()
441        {
442            self.generation.top_k = v;
443        }
444        if let Ok(val) = std::env::var("LLAMA_TOP_P")
445            && let Ok(v) = val.parse()
446        {
447            self.generation.top_p = v;
448        }
449        if let Ok(val) = std::env::var("LLAMA_REPEAT_PENALTY")
450            && let Ok(v) = val.parse()
451        {
452            self.generation.repeat_penalty = v;
453        }
454        if let Ok(val) = std::env::var("LLAMA_MAX_TOKENS")
455            && let Ok(v) = val.parse()
456        {
457            self.generation.max_tokens = v;
458        }
459        if let Ok(val) = std::env::var("LLAMA_SEED")
460            && let Ok(v) = val.parse()
461        {
462            self.generation.seed = Some(v);
463        }
464        if let Ok(val) = std::env::var("LLAMA_HOST") {
465            self.server.host = val;
466        }
467        if let Ok(val) = std::env::var("LLAMA_PORT")
468            && let Ok(v) = val.parse()
469        {
470            self.server.port = v;
471        }
472        if let Ok(val) = std::env::var("LLAMA_SYSTEM_PROMPT") {
473            self.chat.system_prompt = Some(val);
474        }
475    }
476
477    /// Save configuration to a TOML file.
478    pub fn save(&self, path: impl AsRef<Path>) -> Result<(), ConfigError> {
479        let content = toml::to_string_pretty(self)?;
480        std::fs::write(path, content)?;
481        Ok(())
482    }
483
484    /// Convert to an [`EngineConfig`] using the model and generation sections.
485    ///
486    /// The `model_path_override` takes highest priority (from CLI positional arg).
487    pub fn to_engine_config(&self, model_path_override: Option<&str>) -> EngineConfig {
488        let model_path = model_path_override
489            .map(|s| s.to_string())
490            .or_else(|| self.model.path.clone())
491            .unwrap_or_default();
492
493        EngineConfig {
494            model_path,
495            tokenizer_path: None,
496            temperature: self.generation.temperature,
497            top_k: self.generation.top_k,
498            top_p: self.generation.top_p,
499            repeat_penalty: self.generation.repeat_penalty,
500            max_tokens: self.generation.max_tokens,
501            seed: self.generation.seed,
502            use_gpu: self.model.gpu,
503            max_context_len: None,
504            #[cfg(feature = "hailo")]
505            hailo_config: None,
506            kv_cache_type: parse_kv_cache_type(&self.model.kv_cache_type),
507        }
508    }
509
510    /// Convert to an [`EngineConfig`] using chat-specific overrides where present.
511    ///
512    /// Values from `[chat]` override `[generation]` when set.
513    pub fn to_chat_engine_config(&self, model_path_override: Option<&str>) -> EngineConfig {
514        let mut config = self.to_engine_config(model_path_override);
515
516        // Apply chat-specific overrides
517        if let Some(max_tokens) = self.chat.max_tokens {
518            config.max_tokens = max_tokens;
519        }
520        if let Some(temperature) = self.chat.temperature {
521            config.temperature = temperature;
522        }
523        if let Some(top_p) = self.chat.top_p {
524            config.top_p = top_p;
525        }
526        if let Some(top_k) = self.chat.top_k {
527            config.top_k = top_k;
528        }
529        if let Some(repeat_penalty) = self.chat.repeat_penalty {
530            config.repeat_penalty = repeat_penalty;
531        }
532
533        config
534    }
535}
536
537// ============================================================================
538// Example config generator
539// ============================================================================
540
/// Generate an example TOML configuration with all options documented.
///
/// Every table of [`Config`] is present. Options that have a concrete default
/// are shown with that default; purely optional ones are shown commented out.
/// The returned text is valid TOML for this schema.
pub fn example_config() -> &'static str {
    r#"# llama-rs configuration
# All values shown are defaults. Uncomment and modify as needed.
#
# Precedence: CLI arguments > environment variables > this file > defaults

# ─────────────────────────────────────────────────────────────────────
# Model
# ─────────────────────────────────────────────────────────────────────
[model]
# Path to the GGUF model file (can also use LLAMA_MODEL_PATH env var)
# path = "/path/to/model.gguf"

# Use GPU acceleration (CUDA/Metal/Vulkan)
# Also: LLAMA_GPU=1
gpu = false

# KV cache type
# Options: f32, turboquant2, turboquant3, turboquant2-qjl, turboquant3-qjl
# (unset or unrecognized values fall back to f32)
# kv_cache_type = "f32"

# ─────────────────────────────────────────────────────────────────────
# Generation / Sampling Parameters
# Used by: run, chat, serve
# ─────────────────────────────────────────────────────────────────────
[generation]
# Sampling temperature (0.0 = greedy/deterministic, higher = more random)
# Also: LLAMA_TEMPERATURE
temperature = 0.7

# Top-K sampling: only consider the K most likely next tokens (0 = disabled)
# Also: LLAMA_TOP_K
top_k = 40

# Top-P (nucleus) sampling: cumulative probability cutoff
# Also: LLAMA_TOP_P
top_p = 0.95

# Repetition penalty (1.0 = no penalty, higher = less repetition)
# Also: LLAMA_REPEAT_PENALTY
repeat_penalty = 1.1

# Default maximum tokens to generate per request
# Also: LLAMA_MAX_TOKENS
max_tokens = 512

# Random seed for reproducible generation (comment out for random)
# Also: LLAMA_SEED
# seed = 42

# Override the model's native max context length (0 = use model default)
# Reduces KV cache memory for large-context models on constrained hardware
max_context_len = 0

# ─────────────────────────────────────────────────────────────────────
# Chat Mode Overrides
# Values here override [generation] when using the `chat` command.
# Omitted values fall back to [generation].
# ─────────────────────────────────────────────────────────────────────
[chat]
# Default system prompt for chat sessions
# Also: LLAMA_SYSTEM_PROMPT
# system_prompt = "You are a helpful AI assistant."

# Override generation settings for chat specifically
# max_tokens = 1024
# temperature = 0.7
# top_p = 0.9
# top_k = 40
# repeat_penalty = 1.1

# ─────────────────────────────────────────────────────────────────────
# HTTP Server (used by `serve` command)
# ─────────────────────────────────────────────────────────────────────
[server]
# Host address to bind to
# Also: LLAMA_HOST
host = "127.0.0.1"

# Port to listen on
# Also: LLAMA_PORT
port = 8080

# PostgreSQL/pgvector URL for RAG (requires `rag` feature)
# Also: RAG_DATABASE_URL
# rag_database_url = "postgres://user:pass@localhost:5432/mydb"

# Path to separate RAG config file
# rag_config = "rag.toml"

# ─────────────────────────────────────────────────────────────────────
# Quantization (used by `quantize` command)
# ─────────────────────────────────────────────────────────────────────
[quantize]
# Target quantization type
# Options: q4_0, q4_1, q5_0, q5_1, q8_0, q2_k, q3_k, q4_k, q5_k, q6_k
output_type = "q4_0"

# Number of threads (default: all available cores)
# threads = 8

# ─────────────────────────────────────────────────────────────────────
# Benchmarking (used by `bench` command)
# ─────────────────────────────────────────────────────────────────────
[bench]
# Number of prompt tokens to process
n_prompt = 512

# Number of tokens to generate
n_gen = 128

# Number of repetitions for averaging results
repetitions = 3

# Number of threads (default: all available cores)
# threads = 4

# ─────────────────────────────────────────────────────────────────────
# Embeddings (used by `embed` command)
# ─────────────────────────────────────────────────────────────────────
[embed]
# Output format: "json" or "raw"
format = "json"
"#
}
659
/// Unit tests for configuration defaults, TOML (de)serialization, and the
/// conversions into `EngineConfig`.
#[cfg(test)]
mod tests {
    use super::*;

    /// `Config::default()` exposes the documented per-section defaults.
    #[test]
    fn test_default_config() {
        let config = Config::default();
        assert_eq!(config.generation.temperature, 0.7);
        assert_eq!(config.generation.top_k, 40);
        assert_eq!(config.generation.top_p, 0.95);
        assert_eq!(config.generation.max_tokens, 512);
        assert_eq!(config.server.port, 8080);
        assert!(!config.model.gpu);
    }

    /// Serializing and re-parsing a config preserves the customized values.
    #[test]
    fn test_roundtrip_toml() {
        let config = Config {
            model: ModelSection {
                path: Some("/tmp/test.gguf".to_string()),
                gpu: true,
                ..Default::default()
            },
            generation: GenerationSection {
                temperature: 0.5,
                top_k: 50,
                seed: Some(42),
                ..Default::default()
            },
            ..Default::default()
        };

        let toml_str = toml::to_string_pretty(&config).unwrap();
        let parsed: Config = toml::from_str(&toml_str).unwrap();

        assert_eq!(parsed.model.path.as_deref(), Some("/tmp/test.gguf"));
        assert!(parsed.model.gpu);
        assert_eq!(parsed.generation.temperature, 0.5);
        assert_eq!(parsed.generation.top_k, 50);
        assert_eq!(parsed.generation.seed, Some(42));
    }

    /// Model and generation settings are carried over into the engine config.
    #[test]
    fn test_to_engine_config() {
        let config = Config {
            model: ModelSection {
                path: Some("/models/llama.gguf".to_string()),
                gpu: true,
                ..Default::default()
            },
            generation: GenerationSection {
                temperature: 0.3,
                max_tokens: 1024,
                seed: Some(123),
                ..Default::default()
            },
            ..Default::default()
        };

        let engine = config.to_engine_config(None);
        assert_eq!(engine.model_path, "/models/llama.gguf");
        assert_eq!(engine.temperature, 0.3);
        assert_eq!(engine.max_tokens, 1024);
        assert_eq!(engine.seed, Some(123));
        assert!(engine.use_gpu);
    }

    /// An explicit model path beats the one from the config file.
    #[test]
    fn test_model_path_override() {
        let config = Config {
            model: ModelSection {
                path: Some("/config/model.gguf".to_string()),
                ..Default::default()
            },
            ..Default::default()
        };

        // CLI override should win
        let engine = config.to_engine_config(Some("/cli/model.gguf"));
        assert_eq!(engine.model_path, "/cli/model.gguf");

        // Without override, config value is used
        let engine = config.to_engine_config(None);
        assert_eq!(engine.model_path, "/config/model.gguf");
    }

    /// `[chat]` values replace `[generation]` values only where they are set.
    #[test]
    fn test_chat_overrides() {
        let config = Config {
            generation: GenerationSection {
                temperature: 0.8,
                max_tokens: 256,
                ..Default::default()
            },
            chat: ChatSection {
                max_tokens: Some(1024),
                temperature: Some(0.5),
                ..Default::default()
            },
            ..Default::default()
        };

        let engine = config.to_chat_engine_config(None);
        assert_eq!(engine.max_tokens, 1024); // overridden by chat
        assert_eq!(engine.temperature, 0.5); // overridden by chat
        assert_eq!(engine.top_k, 40); // from generation defaults
    }

    /// Missing tables and keys are filled in from the defaults.
    #[test]
    fn test_parse_partial_toml() {
        let toml_str = r#"
[model]
path = "/my/model.gguf"

[generation]
temperature = 0.3
"#;

        let config: Config = toml::from_str(toml_str).unwrap();
        assert_eq!(config.model.path.as_deref(), Some("/my/model.gguf"));
        assert_eq!(config.generation.temperature, 0.3);
        // Defaults should fill in the rest
        assert_eq!(config.generation.top_k, 40);
        assert_eq!(config.server.port, 8080);
    }

    /// The example config is valid TOML for this schema and, because it only
    /// shows default values, deserializes to the defaults.
    #[test]
    fn test_example_config_parses() {
        let example = example_config();
        assert!(example.contains("[model]"));
        assert!(example.contains("[generation]"));
        assert!(example.contains("[chat]"));
        assert!(example.contains("[server]"));

        // Actually parse it: a substring check alone would not catch a
        // syntax error or a key with the wrong type in the example text.
        let parsed: Config =
            toml::from_str(example).expect("example config must be valid TOML for Config");
        assert!(parsed.model.path.is_none());
        assert!(!parsed.model.gpu);
        assert_eq!(parsed.generation.temperature, 0.7);
        assert_eq!(parsed.generation.top_k, 40);
        assert_eq!(parsed.generation.max_tokens, 512);
        assert_eq!(parsed.generation.seed, None);
        assert!(parsed.chat.system_prompt.is_none());
        assert_eq!(parsed.server.host, "127.0.0.1");
        assert_eq!(parsed.server.port, 8080);
        assert_eq!(parsed.quantize.output_type, "q4_0");
        assert_eq!(parsed.bench.n_prompt, 512);
        assert_eq!(parsed.embed.format, "json");
    }

    /// `host_url` is only produced for a non-local host.
    #[test]
    fn test_host_url() {
        let local = ServerSection::default();
        assert_eq!(local.host_url(), None);

        let remote = ServerSection {
            host: "192.168.1.4".to_string(),
            ..Default::default()
        };
        assert_eq!(remote.host_url().as_deref(), Some("http://192.168.1.4:8080"));
    }

    /// Every documented KV cache name maps to its variant; empty input means F32.
    #[test]
    fn test_parse_kv_cache_type() {
        use crate::model::KVCacheType;
        assert_eq!(parse_kv_cache_type("f32"), KVCacheType::F32);
        assert_eq!(parse_kv_cache_type("turboquant2"), KVCacheType::TurboQuantMSE { bits: 2 });
        assert_eq!(parse_kv_cache_type("turboquant3"), KVCacheType::TurboQuantMSE { bits: 3 });
        assert_eq!(parse_kv_cache_type("turboquant2-qjl"), KVCacheType::TurboQuantProd { bits: 2 });
        assert_eq!(parse_kv_cache_type("turboquant3-qjl"), KVCacheType::TurboQuantProd { bits: 3 });
        assert_eq!(parse_kv_cache_type(""), KVCacheType::F32);
    }
}
806
807/// Parse a KV cache type string into a `KVCacheType` enum.
808pub fn parse_kv_cache_type(s: &str) -> crate::model::KVCacheType {
809    use crate::model::KVCacheType;
810    match s.to_lowercase().as_str() {
811        "turboquant2" | "tq2" => KVCacheType::TurboQuantMSE { bits: 2 },
812        "turboquant3" | "tq3" => KVCacheType::TurboQuantMSE { bits: 3 },
813        "turboquant2-qjl" | "tq2-qjl" => KVCacheType::TurboQuantProd { bits: 2 },
814        "turboquant3-qjl" | "tq3-qjl" => KVCacheType::TurboQuantProd { bits: 3 },
815        _ => KVCacheType::F32,
816    }
817}