1use serde::{Deserialize, Serialize};
53use std::path::Path;
54
55use crate::engine::EngineConfig;
56
57#[derive(thiserror::Error, Debug)]
63pub enum ConfigError {
64 #[error("IO error: {0}")]
65 Io(#[from] std::io::Error),
66
67 #[error("TOML parse error: {0}")]
68 Toml(#[from] toml::de::Error),
69
70 #[error("TOML serialize error: {0}")]
71 TomlSerialize(#[from] toml::ser::Error),
72
73 #[error("Config error: {0}")]
74 Other(String),
75}
76
77#[derive(Debug, Clone, Serialize, Deserialize, Default)]
86#[serde(default)]
87pub struct Config {
88 pub model: ModelSection,
90
91 pub generation: GenerationSection,
93
94 pub chat: ChatSection,
96
97 pub server: ServerSection,
99
100 pub quantize: QuantizeSection,
102
103 pub bench: BenchSection,
105
106 pub embed: EmbedSection,
108}
109
110#[derive(Debug, Clone, Serialize, Deserialize)]
116#[serde(default)]
117#[derive(Default)]
118pub struct ModelSection {
119 pub path: Option<String>,
121
122 pub gpu: bool,
124
125 pub kv_cache_type: String,
127}
128
129#[derive(Debug, Clone, Serialize, Deserialize)]
134#[serde(default)]
135pub struct GenerationSection {
136 pub temperature: f32,
138
139 pub top_k: usize,
141
142 pub top_p: f32,
144
145 pub repeat_penalty: f32,
147
148 pub max_tokens: usize,
150
151 pub seed: Option<u64>,
153
154 pub max_context_len: usize,
157}
158
159impl Default for GenerationSection {
160 fn default() -> Self {
161 Self {
162 temperature: 0.7,
163 top_k: 40,
164 top_p: 0.95,
165 repeat_penalty: 1.1,
166 max_tokens: 512,
167 seed: None,
168 max_context_len: 0,
169 }
170 }
171}
172
173#[derive(Debug, Clone, Serialize, Deserialize, Default)]
178#[serde(default)]
179pub struct ChatSection {
180 pub system_prompt: Option<String>,
182
183 pub max_tokens: Option<usize>,
185
186 pub temperature: Option<f32>,
188
189 pub top_p: Option<f32>,
191
192 pub top_k: Option<usize>,
194
195 pub repeat_penalty: Option<f32>,
197}
198
199#[derive(Debug, Clone, Serialize, Deserialize)]
201#[serde(default)]
202pub struct ServerSection {
203 pub host: String,
205
206 pub port: u16,
208
209 pub rag_database_url: Option<String>,
211
212 pub rag_config: Option<String>,
214}
215
216impl Default for ServerSection {
217 fn default() -> Self {
218 Self {
219 host: "127.0.0.1".to_string(),
220 port: 8080,
221 rag_database_url: None,
222 rag_config: None,
223 }
224 }
225}
226
227impl ServerSection {
228 pub fn host_url(&self) -> Option<String> {
231 if self.host == "127.0.0.1" || self.host == "localhost" || self.host == "0.0.0.0" {
233 None
234 } else {
235 Some(format!("http://{}:{}", self.host, self.port))
236 }
237 }
238}
239
240#[derive(Debug, Clone, Serialize, Deserialize)]
242#[serde(default)]
243pub struct QuantizeSection {
244 pub output_type: String,
246
247 pub threads: Option<usize>,
249}
250
251impl Default for QuantizeSection {
252 fn default() -> Self {
253 Self {
254 output_type: "q4_0".to_string(),
255 threads: None,
256 }
257 }
258}
259
260#[derive(Debug, Clone, Serialize, Deserialize)]
262#[serde(default)]
263pub struct BenchSection {
264 pub n_prompt: usize,
266
267 pub n_gen: usize,
269
270 pub repetitions: usize,
272
273 pub threads: Option<usize>,
275}
276
277impl Default for BenchSection {
278 fn default() -> Self {
279 Self {
280 n_prompt: 512,
281 n_gen: 128,
282 repetitions: 3,
283 threads: None,
284 }
285 }
286}
287
288#[derive(Debug, Clone, Serialize, Deserialize)]
290#[serde(default)]
291pub struct EmbedSection {
292 pub format: String,
294}
295
296impl Default for EmbedSection {
297 fn default() -> Self {
298 Self {
299 format: "json".to_string(),
300 }
301 }
302}
303
/// Candidate configuration file locations, relative to the current working
/// directory, listed in the order they are probed.
pub const DEFAULT_CONFIG_PATHS: &[&str] =
    &["llama-rs.toml", "config/llama-rs.toml", ".llama-rs.toml"];
314
315impl Config {
316 pub fn from_file(path: impl AsRef<Path>) -> Result<Self, ConfigError> {
318 let content = std::fs::read_to_string(path.as_ref())?;
319 let config: Self = toml::from_str(&content)?;
320 Ok(config)
321 }
322
323 pub fn from_env() -> Self {
338 let mut config = Self::default();
339
340 if let Ok(path) = std::env::var("LLAMA_MODEL_PATH") {
341 config.model.path = Some(path);
342 }
343 if let Ok(gpu) = std::env::var("LLAMA_GPU") {
344 config.model.gpu = matches!(gpu.to_lowercase().as_str(), "1" | "true" | "yes");
345 }
346 if let Ok(val) = std::env::var("LLAMA_TEMPERATURE")
347 && let Ok(v) = val.parse()
348 {
349 config.generation.temperature = v;
350 }
351 if let Ok(val) = std::env::var("LLAMA_TOP_K")
352 && let Ok(v) = val.parse()
353 {
354 config.generation.top_k = v;
355 }
356 if let Ok(val) = std::env::var("LLAMA_TOP_P")
357 && let Ok(v) = val.parse()
358 {
359 config.generation.top_p = v;
360 }
361 if let Ok(val) = std::env::var("LLAMA_REPEAT_PENALTY")
362 && let Ok(v) = val.parse()
363 {
364 config.generation.repeat_penalty = v;
365 }
366 if let Ok(val) = std::env::var("LLAMA_MAX_TOKENS")
367 && let Ok(v) = val.parse()
368 {
369 config.generation.max_tokens = v;
370 }
371 if let Ok(val) = std::env::var("LLAMA_SEED")
372 && let Ok(v) = val.parse()
373 {
374 config.generation.seed = Some(v);
375 }
376 if let Ok(val) = std::env::var("LLAMA_HOST") {
377 config.server.host = val;
378 }
379 if let Ok(val) = std::env::var("LLAMA_PORT")
380 && let Ok(v) = val.parse()
381 {
382 config.server.port = v;
383 }
384 if let Ok(val) = std::env::var("LLAMA_SYSTEM_PROMPT") {
385 config.chat.system_prompt = Some(val);
386 }
387
388 config
389 }
390
391 pub fn load(config_path: Option<impl AsRef<Path>>) -> Result<Self, ConfigError> {
397 let mut config = Self::default();
398
399 if let Some(path) = config_path {
401 let p = path.as_ref();
402 if p.exists() {
403 config = Self::from_file(p)?;
404 } else {
405 return Err(ConfigError::Other(format!(
406 "Config file not found: {}",
407 p.display()
408 )));
409 }
410 } else {
411 for path in DEFAULT_CONFIG_PATHS {
413 if Path::new(path).exists() {
414 config = Self::from_file(path)?;
415 break;
416 }
417 }
418 }
419
420 config.apply_env();
422
423 Ok(config)
424 }
425
426 pub fn apply_env(&mut self) {
428 if let Ok(path) = std::env::var("LLAMA_MODEL_PATH") {
429 self.model.path = Some(path);
430 }
431 if let Ok(gpu) = std::env::var("LLAMA_GPU") {
432 self.model.gpu = matches!(gpu.to_lowercase().as_str(), "1" | "true" | "yes");
433 }
434 if let Ok(val) = std::env::var("LLAMA_TEMPERATURE")
435 && let Ok(v) = val.parse()
436 {
437 self.generation.temperature = v;
438 }
439 if let Ok(val) = std::env::var("LLAMA_TOP_K")
440 && let Ok(v) = val.parse()
441 {
442 self.generation.top_k = v;
443 }
444 if let Ok(val) = std::env::var("LLAMA_TOP_P")
445 && let Ok(v) = val.parse()
446 {
447 self.generation.top_p = v;
448 }
449 if let Ok(val) = std::env::var("LLAMA_REPEAT_PENALTY")
450 && let Ok(v) = val.parse()
451 {
452 self.generation.repeat_penalty = v;
453 }
454 if let Ok(val) = std::env::var("LLAMA_MAX_TOKENS")
455 && let Ok(v) = val.parse()
456 {
457 self.generation.max_tokens = v;
458 }
459 if let Ok(val) = std::env::var("LLAMA_SEED")
460 && let Ok(v) = val.parse()
461 {
462 self.generation.seed = Some(v);
463 }
464 if let Ok(val) = std::env::var("LLAMA_HOST") {
465 self.server.host = val;
466 }
467 if let Ok(val) = std::env::var("LLAMA_PORT")
468 && let Ok(v) = val.parse()
469 {
470 self.server.port = v;
471 }
472 if let Ok(val) = std::env::var("LLAMA_SYSTEM_PROMPT") {
473 self.chat.system_prompt = Some(val);
474 }
475 }
476
477 pub fn save(&self, path: impl AsRef<Path>) -> Result<(), ConfigError> {
479 let content = toml::to_string_pretty(self)?;
480 std::fs::write(path, content)?;
481 Ok(())
482 }
483
484 pub fn to_engine_config(&self, model_path_override: Option<&str>) -> EngineConfig {
488 let model_path = model_path_override
489 .map(|s| s.to_string())
490 .or_else(|| self.model.path.clone())
491 .unwrap_or_default();
492
493 EngineConfig {
494 model_path,
495 tokenizer_path: None,
496 temperature: self.generation.temperature,
497 top_k: self.generation.top_k,
498 top_p: self.generation.top_p,
499 repeat_penalty: self.generation.repeat_penalty,
500 max_tokens: self.generation.max_tokens,
501 seed: self.generation.seed,
502 use_gpu: self.model.gpu,
503 max_context_len: None,
504 #[cfg(feature = "hailo")]
505 hailo_config: None,
506 kv_cache_type: parse_kv_cache_type(&self.model.kv_cache_type),
507 }
508 }
509
510 pub fn to_chat_engine_config(&self, model_path_override: Option<&str>) -> EngineConfig {
514 let mut config = self.to_engine_config(model_path_override);
515
516 if let Some(max_tokens) = self.chat.max_tokens {
518 config.max_tokens = max_tokens;
519 }
520 if let Some(temperature) = self.chat.temperature {
521 config.temperature = temperature;
522 }
523 if let Some(top_p) = self.chat.top_p {
524 config.top_p = top_p;
525 }
526 if let Some(top_k) = self.chat.top_k {
527 config.top_k = top_k;
528 }
529 if let Some(repeat_penalty) = self.chat.repeat_penalty {
530 config.repeat_penalty = repeat_penalty;
531 }
532
533 config
534 }
535}
536
/// Returns a commented example configuration file in TOML format.
///
/// Active (uncommented) keys carry the built-in default values; optional
/// keys are shown commented out.
pub fn example_config() -> &'static str {
    r#"# llama-rs configuration
# All values shown are defaults. Uncomment and modify as needed.
#
# Precedence: CLI arguments > environment variables > this file > defaults

# ─────────────────────────────────────────────────────────────────────
# Model
# ─────────────────────────────────────────────────────────────────────
[model]
# Path to the GGUF model file (can also use LLAMA_MODEL_PATH env var)
# path = "/path/to/model.gguf"

# Use GPU acceleration (CUDA/Metal/Vulkan)
# Also: LLAMA_GPU=1
gpu = false

# KV cache type (unset or unrecognized values fall back to f32)
# Options: f32, turboquant2 (tq2), turboquant3 (tq3),
#          turboquant2-qjl (tq2-qjl), turboquant3-qjl (tq3-qjl)
# kv_cache_type = "f32"

# ─────────────────────────────────────────────────────────────────────
# Generation / Sampling Parameters
# Used by: run, chat, serve
# ─────────────────────────────────────────────────────────────────────
[generation]
# Sampling temperature (0.0 = greedy/deterministic, higher = more random)
# Also: LLAMA_TEMPERATURE
temperature = 0.7

# Top-K sampling: only consider the K most likely next tokens (0 = disabled)
# Also: LLAMA_TOP_K
top_k = 40

# Top-P (nucleus) sampling: cumulative probability cutoff
# Also: LLAMA_TOP_P
top_p = 0.95

# Repetition penalty (1.0 = no penalty, higher = less repetition)
# Also: LLAMA_REPEAT_PENALTY
repeat_penalty = 1.1

# Default maximum tokens to generate per request
# Also: LLAMA_MAX_TOKENS
max_tokens = 512

# Random seed for reproducible generation (comment out for random)
# Also: LLAMA_SEED
# seed = 42

# ─────────────────────────────────────────────────────────────────────
# Chat Mode Overrides
# Values here override [generation] when using the `chat` command.
# Omitted values fall back to [generation].
# ─────────────────────────────────────────────────────────────────────
[chat]
# Default system prompt for chat sessions
# Also: LLAMA_SYSTEM_PROMPT
# system_prompt = "You are a helpful AI assistant."

# Override generation settings for chat specifically
# max_tokens = 1024
# temperature = 0.7
# top_p = 0.9
# top_k = 40
# repeat_penalty = 1.1

# ─────────────────────────────────────────────────────────────────────
# HTTP Server (used by `serve` command)
# ─────────────────────────────────────────────────────────────────────
[server]
# Host address to bind to
# Also: LLAMA_HOST
host = "127.0.0.1"

# Port to listen on
# Also: LLAMA_PORT
port = 8080

# PostgreSQL/pgvector URL for RAG (requires `rag` feature)
# Also: RAG_DATABASE_URL
# rag_database_url = "postgres://user:pass@localhost:5432/mydb"

# Path to separate RAG config file
# rag_config = "rag.toml"

# ─────────────────────────────────────────────────────────────────────
# Quantization (used by `quantize` command)
# ─────────────────────────────────────────────────────────────────────
[quantize]
# Target quantization type
# Options: q4_0, q4_1, q5_0, q5_1, q8_0, q2_k, q3_k, q4_k, q5_k, q6_k
output_type = "q4_0"

# Number of threads (default: all available cores)
# threads = 8

# ─────────────────────────────────────────────────────────────────────
# Benchmarking (used by `bench` command)
# ─────────────────────────────────────────────────────────────────────
[bench]
# Number of prompt tokens to process
n_prompt = 512

# Number of tokens to generate
n_gen = 128

# Number of repetitions for averaging results
repetitions = 3

# Number of threads (default: all available cores)
# threads = 4

# ─────────────────────────────────────────────────────────────────────
# Embeddings (used by `embed` command)
# ─────────────────────────────────────────────────────────────────────
[embed]
# Output format: "json" or "raw"
format = "json"
"#
}
659
#[cfg(test)]
mod tests {
    use super::*;

    /// The built-in defaults match the documented values.
    #[test]
    fn test_default_config() {
        let config = Config::default();
        assert_eq!(config.generation.temperature, 0.7);
        assert_eq!(config.generation.top_k, 40);
        assert_eq!(config.generation.top_p, 0.95);
        assert_eq!(config.generation.max_tokens, 512);
        assert_eq!(config.server.port, 8080);
        assert!(!config.model.gpu);
    }

    /// Serializing to TOML and parsing back preserves explicitly set fields.
    #[test]
    fn test_roundtrip_toml() {
        let config = Config {
            model: ModelSection {
                path: Some("/tmp/test.gguf".to_string()),
                gpu: true,
                ..Default::default()
            },
            generation: GenerationSection {
                temperature: 0.5,
                top_k: 50,
                seed: Some(42),
                ..Default::default()
            },
            ..Default::default()
        };

        let toml_str = toml::to_string_pretty(&config).unwrap();
        let parsed: Config = toml::from_str(&toml_str).unwrap();

        assert_eq!(parsed.model.path.as_deref(), Some("/tmp/test.gguf"));
        assert!(parsed.model.gpu);
        assert_eq!(parsed.generation.temperature, 0.5);
        assert_eq!(parsed.generation.top_k, 50);
        assert_eq!(parsed.generation.seed, Some(42));
    }

    /// `[model]` and `[generation]` values are carried into the engine config.
    #[test]
    fn test_to_engine_config() {
        let config = Config {
            model: ModelSection {
                path: Some("/models/llama.gguf".to_string()),
                gpu: true,
                ..Default::default()
            },
            generation: GenerationSection {
                temperature: 0.3,
                max_tokens: 1024,
                seed: Some(123),
                ..Default::default()
            },
            ..Default::default()
        };

        let engine = config.to_engine_config(None);
        assert_eq!(engine.model_path, "/models/llama.gguf");
        assert_eq!(engine.temperature, 0.3);
        assert_eq!(engine.max_tokens, 1024);
        assert_eq!(engine.seed, Some(123));
        assert!(engine.use_gpu);
    }

    /// An explicit override wins over `model.path`; without one the
    /// configured path is used.
    #[test]
    fn test_model_path_override() {
        let config = Config {
            model: ModelSection {
                path: Some("/config/model.gguf".to_string()),
                ..Default::default()
            },
            ..Default::default()
        };

        let engine = config.to_engine_config(Some("/cli/model.gguf"));
        assert_eq!(engine.model_path, "/cli/model.gguf");

        let engine = config.to_engine_config(None);
        assert_eq!(engine.model_path, "/config/model.gguf");
    }

    /// `[chat]` values replace `[generation]` values; unset ones fall back.
    #[test]
    fn test_chat_overrides() {
        let config = Config {
            generation: GenerationSection {
                temperature: 0.8,
                max_tokens: 256,
                ..Default::default()
            },
            chat: ChatSection {
                max_tokens: Some(1024),
                temperature: Some(0.5),
                ..Default::default()
            },
            ..Default::default()
        };

        let engine = config.to_chat_engine_config(None);
        // Overridden by [chat].
        assert_eq!(engine.max_tokens, 1024);
        assert_eq!(engine.temperature, 0.5);
        // Not overridden: falls back to the [generation] default.
        assert_eq!(engine.top_k, 40);
    }

    /// Keys and tables missing from the input take their default values.
    #[test]
    fn test_parse_partial_toml() {
        let toml_str = r#"
[model]
path = "/my/model.gguf"

[generation]
temperature = 0.3
"#;

        let config: Config = toml::from_str(toml_str).unwrap();
        assert_eq!(config.model.path.as_deref(), Some("/my/model.gguf"));
        assert_eq!(config.generation.temperature, 0.3);
        assert_eq!(config.generation.top_k, 40);
        assert_eq!(config.server.port, 8080);
    }

    /// The example configuration contains the expected tables, actually
    /// deserializes, and agrees with the built-in defaults.
    #[test]
    fn test_example_config_parses() {
        let example = example_config();
        assert!(example.contains("[model]"));
        assert!(example.contains("[generation]"));
        assert!(example.contains("[chat]"));
        assert!(example.contains("[server]"));

        let parsed: Config =
            toml::from_str(example).expect("example config must be valid TOML for Config");
        let defaults = Config::default();
        assert_eq!(parsed.model.gpu, defaults.model.gpu);
        assert_eq!(parsed.generation.temperature, defaults.generation.temperature);
        assert_eq!(parsed.generation.top_k, defaults.generation.top_k);
        assert_eq!(parsed.generation.max_tokens, defaults.generation.max_tokens);
        assert_eq!(parsed.server.host, defaults.server.host);
        assert_eq!(parsed.server.port, defaults.server.port);
        assert_eq!(parsed.quantize.output_type, defaults.quantize.output_type);
        assert_eq!(parsed.bench.n_prompt, defaults.bench.n_prompt);
        assert_eq!(parsed.embed.format, defaults.embed.format);
    }

    /// Every recognized name, alias, and casing maps to the right cache
    /// type; anything else falls back to `F32`.
    #[test]
    fn test_parse_kv_cache_type() {
        use crate::model::KVCacheType;
        assert_eq!(parse_kv_cache_type("f32"), KVCacheType::F32);
        assert_eq!(
            parse_kv_cache_type("turboquant2"),
            KVCacheType::TurboQuantMSE { bits: 2 }
        );
        assert_eq!(
            parse_kv_cache_type("turboquant3"),
            KVCacheType::TurboQuantMSE { bits: 3 }
        );
        assert_eq!(
            parse_kv_cache_type("turboquant2-qjl"),
            KVCacheType::TurboQuantProd { bits: 2 }
        );
        assert_eq!(
            parse_kv_cache_type("turboquant3-qjl"),
            KVCacheType::TurboQuantProd { bits: 3 }
        );
        assert_eq!(parse_kv_cache_type(""), KVCacheType::F32);

        // Short aliases and case-insensitive matching.
        assert_eq!(
            parse_kv_cache_type("tq2"),
            KVCacheType::TurboQuantMSE { bits: 2 }
        );
        assert_eq!(
            parse_kv_cache_type("TQ3-QJL"),
            KVCacheType::TurboQuantProd { bits: 3 }
        );
        assert_eq!(parse_kv_cache_type("unknown"), KVCacheType::F32);
    }
}
806
807pub fn parse_kv_cache_type(s: &str) -> crate::model::KVCacheType {
809 use crate::model::KVCacheType;
810 match s.to_lowercase().as_str() {
811 "turboquant2" | "tq2" => KVCacheType::TurboQuantMSE { bits: 2 },
812 "turboquant3" | "tq3" => KVCacheType::TurboQuantMSE { bits: 3 },
813 "turboquant2-qjl" | "tq2-qjl" => KVCacheType::TurboQuantProd { bits: 2 },
814 "turboquant3-qjl" | "tq3-qjl" => KVCacheType::TurboQuantProd { bits: 3 },
815 _ => KVCacheType::F32,
816 }
817}