use serde::{Deserialize, Serialize};
use std::path::Path;
use crate::engine::EngineConfig;
/// Errors that can occur while loading, parsing, or saving configuration.
#[derive(thiserror::Error, Debug)]
pub enum ConfigError {
    /// Underlying filesystem error while reading or writing a config file.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// The config file contained invalid TOML.
    #[error("TOML parse error: {0}")]
    Toml(#[from] toml::de::Error),
    /// Serializing the configuration to TOML failed (see `Config::save`).
    #[error("TOML serialize error: {0}")]
    TomlSerialize(#[from] toml::ser::Error),
    /// Any other configuration problem, described by a message
    /// (e.g. an explicitly requested config file that does not exist).
    #[error("Config error: {0}")]
    Other(String),
}
/// Top-level application configuration, typically loaded from a TOML file.
///
/// Every section is optional in the file: `#[serde(default)]` makes a
/// missing section fall back to its `Default`, so partial configs parse.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(default)]
pub struct Config {
    /// Model selection and loading options (`[model]`).
    pub model: ModelSection,
    /// Sampling parameters shared by generation commands (`[generation]`).
    pub generation: GenerationSection,
    /// Chat-specific overrides of the generation settings (`[chat]`).
    pub chat: ChatSection,
    /// HTTP server settings for the `serve` command (`[server]`).
    pub server: ServerSection,
    /// Quantization settings for the `quantize` command (`[quantize]`).
    pub quantize: QuantizeSection,
    /// Benchmark settings for the `bench` command (`[bench]`).
    pub bench: BenchSection,
    /// Embedding output settings for the `embed` command (`[embed]`).
    pub embed: EmbedSection,
}
/// Model selection and loading options.
// Consistency fix: `Default` was previously in a second, separate
// `#[derive]` attribute; every other section in this file uses a single
// derive list, so merge them.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(default)]
pub struct ModelSection {
    /// Path to the GGUF model file; `None` means "not configured".
    pub path: Option<String>,
    /// Whether to enable GPU acceleration.
    pub gpu: bool,
    /// KV-cache type string, interpreted by `parse_kv_cache_type`
    /// (unrecognized or empty values fall back to F32).
    pub kv_cache_type: String,
}
/// Sampling / generation parameters, used by the `run`, `chat`, and
/// `serve` commands (per the commentary in `example_config`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct GenerationSection {
    /// Sampling temperature (0.0 = greedy/deterministic).
    pub temperature: f32,
    /// Top-K sampling cutoff (0 = disabled).
    pub top_k: usize,
    /// Top-P (nucleus) cumulative-probability cutoff.
    pub top_p: f32,
    /// Repetition penalty (1.0 = no penalty).
    pub repeat_penalty: f32,
    /// Default maximum tokens to generate per request.
    pub max_tokens: usize,
    /// RNG seed for reproducible output; `None` = random.
    pub seed: Option<u64>,
    /// Maximum context length; defaults to 0, which `to_engine_config`
    /// does not currently forward — presumably 0 means "unset".
    pub max_context_len: usize,
}
impl Default for GenerationSection {
fn default() -> Self {
Self {
temperature: 0.7,
top_k: 40,
top_p: 0.95,
repeat_penalty: 1.1,
max_tokens: 512,
seed: None,
max_context_len: 0,
}
}
}
/// Chat-mode overrides: any `Some` value here takes precedence over the
/// corresponding `[generation]` setting (applied in
/// `Config::to_chat_engine_config`); `None` falls back to `[generation]`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(default)]
pub struct ChatSection {
    /// Default system prompt for chat sessions.
    pub system_prompt: Option<String>,
    /// Override of `generation.max_tokens` for chat.
    pub max_tokens: Option<usize>,
    /// Override of `generation.temperature` for chat.
    pub temperature: Option<f32>,
    /// Override of `generation.top_p` for chat.
    pub top_p: Option<f32>,
    /// Override of `generation.top_k` for chat.
    pub top_k: Option<usize>,
    /// Override of `generation.repeat_penalty` for chat.
    pub repeat_penalty: Option<f32>,
}
/// HTTP server settings, used by the `serve` command.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct ServerSection {
    /// Host address to bind to.
    pub host: String,
    /// TCP port to listen on.
    pub port: u16,
    /// PostgreSQL/pgvector URL for RAG (per `example_config`, requires
    /// the `rag` feature) — TODO confirm how the server consumes this.
    pub rag_database_url: Option<String>,
    /// Path to a separate RAG configuration file.
    pub rag_config: Option<String>,
}
impl Default for ServerSection {
fn default() -> Self {
Self {
host: "127.0.0.1".to_string(),
port: 8080,
rag_database_url: None,
rag_config: None,
}
}
}
impl ServerSection {
    /// Returns a base URL (`http://host:port`) for this server, or `None`
    /// when bound to a local/wildcard address (127.0.0.1, localhost,
    /// 0.0.0.0) where no externally meaningful URL can be formed.
    pub fn host_url(&self) -> Option<String> {
        let is_local = matches!(
            self.host.as_str(),
            "127.0.0.1" | "localhost" | "0.0.0.0"
        );
        if is_local {
            None
        } else {
            Some(format!("http://{}:{}", self.host, self.port))
        }
    }
}
/// Quantization settings, used by the `quantize` command.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct QuantizeSection {
    /// Target quantization type, e.g. "q4_0" (see `example_config` for
    /// the full list of options).
    pub output_type: String,
    /// Worker thread count; `None` means use all available cores
    /// (per the example config commentary).
    pub threads: Option<usize>,
}
impl Default for QuantizeSection {
fn default() -> Self {
Self {
output_type: "q4_0".to_string(),
threads: None,
}
}
}
/// Benchmarking settings, used by the `bench` command.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct BenchSection {
    /// Number of prompt tokens to process.
    pub n_prompt: usize,
    /// Number of tokens to generate.
    pub n_gen: usize,
    /// Number of repetitions for averaging results.
    pub repetitions: usize,
    /// Worker thread count; `None` means use all available cores.
    pub threads: Option<usize>,
}
impl Default for BenchSection {
fn default() -> Self {
Self {
n_prompt: 512,
n_gen: 128,
repetitions: 3,
threads: None,
}
}
}
/// Embedding output settings, used by the `embed` command.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct EmbedSection {
    /// Output format: "json" or "raw" (per `example_config`).
    pub format: String,
}
impl Default for EmbedSection {
fn default() -> Self {
Self {
format: "json".to_string(),
}
}
}
/// Filenames probed in order by `Config::load` when no explicit config
/// path is given; the first existing file wins.
pub const DEFAULT_CONFIG_PATHS: &[&str] = &[
    "llama-gguf.toml",
    "config/llama-gguf.toml",
    ".llama-gguf.toml",
];
impl Config {
pub fn from_file(path: impl AsRef<Path>) -> Result<Self, ConfigError> {
let content = std::fs::read_to_string(path.as_ref())?;
let config: Self = toml::from_str(&content)?;
Ok(config)
}
pub fn from_env() -> Self {
let mut config = Self::default();
if let Ok(path) = std::env::var("LLAMA_MODEL_PATH") {
config.model.path = Some(path);
}
if let Ok(gpu) = std::env::var("LLAMA_GPU") {
config.model.gpu = matches!(gpu.to_lowercase().as_str(), "1" | "true" | "yes");
}
if let Ok(val) = std::env::var("LLAMA_TEMPERATURE")
&& let Ok(v) = val.parse()
{
config.generation.temperature = v;
}
if let Ok(val) = std::env::var("LLAMA_TOP_K")
&& let Ok(v) = val.parse()
{
config.generation.top_k = v;
}
if let Ok(val) = std::env::var("LLAMA_TOP_P")
&& let Ok(v) = val.parse()
{
config.generation.top_p = v;
}
if let Ok(val) = std::env::var("LLAMA_REPEAT_PENALTY")
&& let Ok(v) = val.parse()
{
config.generation.repeat_penalty = v;
}
if let Ok(val) = std::env::var("LLAMA_MAX_TOKENS")
&& let Ok(v) = val.parse()
{
config.generation.max_tokens = v;
}
if let Ok(val) = std::env::var("LLAMA_SEED")
&& let Ok(v) = val.parse()
{
config.generation.seed = Some(v);
}
if let Ok(val) = std::env::var("LLAMA_HOST") {
config.server.host = val;
}
if let Ok(val) = std::env::var("LLAMA_PORT")
&& let Ok(v) = val.parse()
{
config.server.port = v;
}
if let Ok(val) = std::env::var("LLAMA_SYSTEM_PROMPT") {
config.chat.system_prompt = Some(val);
}
config
}
pub fn load(config_path: Option<impl AsRef<Path>>) -> Result<Self, ConfigError> {
let mut config = Self::default();
if let Some(path) = config_path {
let p = path.as_ref();
if p.exists() {
config = Self::from_file(p)?;
} else {
return Err(ConfigError::Other(format!(
"Config file not found: {}",
p.display()
)));
}
} else {
for path in DEFAULT_CONFIG_PATHS {
if Path::new(path).exists() {
config = Self::from_file(path)?;
break;
}
}
}
config.apply_env();
Ok(config)
}
pub fn apply_env(&mut self) {
if let Ok(path) = std::env::var("LLAMA_MODEL_PATH") {
self.model.path = Some(path);
}
if let Ok(gpu) = std::env::var("LLAMA_GPU") {
self.model.gpu = matches!(gpu.to_lowercase().as_str(), "1" | "true" | "yes");
}
if let Ok(val) = std::env::var("LLAMA_TEMPERATURE")
&& let Ok(v) = val.parse()
{
self.generation.temperature = v;
}
if let Ok(val) = std::env::var("LLAMA_TOP_K")
&& let Ok(v) = val.parse()
{
self.generation.top_k = v;
}
if let Ok(val) = std::env::var("LLAMA_TOP_P")
&& let Ok(v) = val.parse()
{
self.generation.top_p = v;
}
if let Ok(val) = std::env::var("LLAMA_REPEAT_PENALTY")
&& let Ok(v) = val.parse()
{
self.generation.repeat_penalty = v;
}
if let Ok(val) = std::env::var("LLAMA_MAX_TOKENS")
&& let Ok(v) = val.parse()
{
self.generation.max_tokens = v;
}
if let Ok(val) = std::env::var("LLAMA_SEED")
&& let Ok(v) = val.parse()
{
self.generation.seed = Some(v);
}
if let Ok(val) = std::env::var("LLAMA_HOST") {
self.server.host = val;
}
if let Ok(val) = std::env::var("LLAMA_PORT")
&& let Ok(v) = val.parse()
{
self.server.port = v;
}
if let Ok(val) = std::env::var("LLAMA_SYSTEM_PROMPT") {
self.chat.system_prompt = Some(val);
}
}
pub fn save(&self, path: impl AsRef<Path>) -> Result<(), ConfigError> {
let content = toml::to_string_pretty(self)?;
std::fs::write(path, content)?;
Ok(())
}
pub fn to_engine_config(&self, model_path_override: Option<&str>) -> EngineConfig {
let model_path = model_path_override
.map(|s| s.to_string())
.or_else(|| self.model.path.clone())
.unwrap_or_default();
EngineConfig {
model_path,
tokenizer_path: None,
temperature: self.generation.temperature,
top_k: self.generation.top_k,
top_p: self.generation.top_p,
repeat_penalty: self.generation.repeat_penalty,
max_tokens: self.generation.max_tokens,
seed: self.generation.seed,
use_gpu: self.model.gpu,
max_context_len: None,
#[cfg(feature = "hailo")]
hailo_config: None,
kv_cache_type: parse_kv_cache_type(&self.model.kv_cache_type),
}
}
pub fn to_chat_engine_config(&self, model_path_override: Option<&str>) -> EngineConfig {
let mut config = self.to_engine_config(model_path_override);
if let Some(max_tokens) = self.chat.max_tokens {
config.max_tokens = max_tokens;
}
if let Some(temperature) = self.chat.temperature {
config.temperature = temperature;
}
if let Some(top_p) = self.chat.top_p {
config.top_p = top_p;
}
if let Some(top_k) = self.chat.top_k {
config.top_k = top_k;
}
if let Some(repeat_penalty) = self.chat.repeat_penalty {
config.repeat_penalty = repeat_penalty;
}
config
}
}
/// Returns a fully commented example configuration file as a static
/// string (suitable for `--init`-style generation or documentation).
/// Every value shown is the corresponding `Default`.
///
/// Note: the raw string below is column-sensitive content, not code —
/// its bytes are emitted verbatim.
pub fn example_config() -> &'static str {
    r#"# llama-gguf configuration
# All values shown are defaults. Uncomment and modify as needed.
#
# Precedence: CLI arguments > environment variables > this file > defaults
# ─────────────────────────────────────────────────────────────────────
# Model
# ─────────────────────────────────────────────────────────────────────
[model]
# Path to the GGUF model file (can also use LLAMA_MODEL_PATH env var)
# path = "/path/to/model.gguf"
# Use GPU acceleration (CUDA/Metal/Vulkan)
# Also: LLAMA_GPU=1
gpu = false
# ─────────────────────────────────────────────────────────────────────
# Generation / Sampling Parameters
# Used by: run, chat, serve
# ─────────────────────────────────────────────────────────────────────
[generation]
# Sampling temperature (0.0 = greedy/deterministic, higher = more random)
# Also: LLAMA_TEMPERATURE
temperature = 0.7
# Top-K sampling: only consider the K most likely next tokens (0 = disabled)
# Also: LLAMA_TOP_K
top_k = 40
# Top-P (nucleus) sampling: cumulative probability cutoff
# Also: LLAMA_TOP_P
top_p = 0.95
# Repetition penalty (1.0 = no penalty, higher = less repetition)
# Also: LLAMA_REPEAT_PENALTY
repeat_penalty = 1.1
# Default maximum tokens to generate per request
# Also: LLAMA_MAX_TOKENS
max_tokens = 512
# Random seed for reproducible generation (comment out for random)
# Also: LLAMA_SEED
# seed = 42
# ─────────────────────────────────────────────────────────────────────
# Chat Mode Overrides
# Values here override [generation] when using the `chat` command.
# Omitted values fall back to [generation].
# ─────────────────────────────────────────────────────────────────────
[chat]
# Default system prompt for chat sessions
# Also: LLAMA_SYSTEM_PROMPT
# system_prompt = "You are a helpful AI assistant."
# Override generation settings for chat specifically
# max_tokens = 1024
# temperature = 0.7
# top_p = 0.9
# top_k = 40
# repeat_penalty = 1.1
# ─────────────────────────────────────────────────────────────────────
# HTTP Server (used by `serve` command)
# ─────────────────────────────────────────────────────────────────────
[server]
# Host address to bind to
# Also: LLAMA_HOST
host = "127.0.0.1"
# Port to listen on
# Also: LLAMA_PORT
port = 8080
# PostgreSQL/pgvector URL for RAG (requires `rag` feature)
# Also: RAG_DATABASE_URL
# rag_database_url = "postgres://user:pass@localhost:5432/mydb"
# Path to separate RAG config file
# rag_config = "rag.toml"
# ─────────────────────────────────────────────────────────────────────
# Quantization (used by `quantize` command)
# ─────────────────────────────────────────────────────────────────────
[quantize]
# Target quantization type
# Options: q4_0, q4_1, q5_0, q5_1, q8_0, q2_k, q3_k, q4_k, q5_k, q6_k
output_type = "q4_0"
# Number of threads (default: all available cores)
# threads = 8
# ─────────────────────────────────────────────────────────────────────
# Benchmarking (used by `bench` command)
# ─────────────────────────────────────────────────────────────────────
[bench]
# Number of prompt tokens to process
n_prompt = 512
# Number of tokens to generate
n_gen = 128
# Number of repetitions for averaging results
repetitions = 3
# Number of threads (default: all available cores)
# threads = 4
# ─────────────────────────────────────────────────────────────────────
# Embeddings (used by `embed` command)
# ─────────────────────────────────────────────────────────────────────
[embed]
# Output format: "json" or "raw"
format = "json"
"#
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Defaults must match the values documented in `example_config`.
    #[test]
    fn test_default_config() {
        let config = Config::default();
        assert_eq!(config.generation.temperature, 0.7);
        assert_eq!(config.generation.top_k, 40);
        assert_eq!(config.generation.top_p, 0.95);
        assert_eq!(config.generation.max_tokens, 512);
        assert_eq!(config.server.port, 8080);
        assert!(!config.model.gpu);
    }

    /// Serializing to TOML and parsing back must preserve all values.
    #[test]
    fn test_roundtrip_toml() {
        let config = Config {
            model: ModelSection {
                path: Some("/tmp/test.gguf".to_string()),
                gpu: true,
                ..Default::default()
            },
            generation: GenerationSection {
                temperature: 0.5,
                top_k: 50,
                seed: Some(42),
                ..Default::default()
            },
            ..Default::default()
        };
        let toml_str = toml::to_string_pretty(&config).unwrap();
        let parsed: Config = toml::from_str(&toml_str).unwrap();
        assert_eq!(parsed.model.path.as_deref(), Some("/tmp/test.gguf"));
        assert!(parsed.model.gpu);
        assert_eq!(parsed.generation.temperature, 0.5);
        assert_eq!(parsed.generation.top_k, 50);
        assert_eq!(parsed.generation.seed, Some(42));
    }

    /// `to_engine_config` must copy the generation/model settings through.
    #[test]
    fn test_to_engine_config() {
        let config = Config {
            model: ModelSection {
                path: Some("/models/llama.gguf".to_string()),
                gpu: true,
                ..Default::default()
            },
            generation: GenerationSection {
                temperature: 0.3,
                max_tokens: 1024,
                seed: Some(123),
                ..Default::default()
            },
            ..Default::default()
        };
        let engine = config.to_engine_config(None);
        assert_eq!(engine.model_path, "/models/llama.gguf");
        assert_eq!(engine.temperature, 0.3);
        assert_eq!(engine.max_tokens, 1024);
        assert_eq!(engine.seed, Some(123));
        assert!(engine.use_gpu);
    }

    /// A CLI-provided model path wins over the configured one.
    #[test]
    fn test_model_path_override() {
        let config = Config {
            model: ModelSection {
                path: Some("/config/model.gguf".to_string()),
                ..Default::default()
            },
            ..Default::default()
        };
        let engine = config.to_engine_config(Some("/cli/model.gguf"));
        assert_eq!(engine.model_path, "/cli/model.gguf");
        let engine = config.to_engine_config(None);
        assert_eq!(engine.model_path, "/config/model.gguf");
    }

    /// `[chat]` `Some` values override `[generation]`; `None` falls back.
    #[test]
    fn test_chat_overrides() {
        let config = Config {
            generation: GenerationSection {
                temperature: 0.8,
                max_tokens: 256,
                ..Default::default()
            },
            chat: ChatSection {
                max_tokens: Some(1024),
                temperature: Some(0.5),
                ..Default::default()
            },
            ..Default::default()
        };
        let engine = config.to_chat_engine_config(None);
        assert_eq!(engine.max_tokens, 1024);
        assert_eq!(engine.temperature, 0.5);
        // top_k was not overridden, so the generation default applies.
        assert_eq!(engine.top_k, 40);
    }

    /// Partial TOML files parse, with missing values taking defaults.
    #[test]
    fn test_parse_partial_toml() {
        let toml_str = r#"
[model]
path = "/my/model.gguf"
[generation]
temperature = 0.3
"#;
        let config: Config = toml::from_str(toml_str).unwrap();
        assert_eq!(config.model.path.as_deref(), Some("/my/model.gguf"));
        assert_eq!(config.generation.temperature, 0.3);
        assert_eq!(config.generation.top_k, 40);
        assert_eq!(config.server.port, 8080);
    }

    /// The shipped example must contain the expected sections AND be
    /// valid, parseable configuration.
    #[test]
    fn test_example_config_parses() {
        let example = example_config();
        assert!(example.contains("[model]"));
        assert!(example.contains("[generation]"));
        assert!(example.contains("[chat]"));
        assert!(example.contains("[server]"));
        // Bug fix: despite its name, this test previously never parsed the
        // example. Actually round-trip it through the TOML parser.
        let parsed: Config =
            toml::from_str(example).expect("example config should be valid TOML");
        assert_eq!(parsed.generation.temperature, 0.7);
        assert_eq!(parsed.server.port, 8080);
        assert!(!parsed.model.gpu);
    }

    /// Known cache-type strings map to their variants; unknown/empty → F32.
    #[test]
    fn test_parse_kv_cache_type() {
        use crate::model::KVCacheType;
        assert_eq!(parse_kv_cache_type("f32"), KVCacheType::F32);
        assert_eq!(parse_kv_cache_type("turboquant2"), KVCacheType::TurboQuantMSE { bits: 2 });
        assert_eq!(parse_kv_cache_type("turboquant3"), KVCacheType::TurboQuantMSE { bits: 3 });
        assert_eq!(parse_kv_cache_type("turboquant2-qjl"), KVCacheType::TurboQuantProd { bits: 2 });
        assert_eq!(parse_kv_cache_type("turboquant3-qjl"), KVCacheType::TurboQuantProd { bits: 3 });
        assert_eq!(parse_kv_cache_type(""), KVCacheType::F32);
    }
}
/// Map a user-supplied KV-cache type string (case-insensitive) to a
/// `KVCacheType`. Both the long form ("turboquant2") and the short alias
/// ("tq2") are accepted; anything unrecognized — including the empty
/// string — falls back to full-precision `F32`.
pub fn parse_kv_cache_type(s: &str) -> crate::model::KVCacheType {
    use crate::model::KVCacheType;
    let normalized = s.to_lowercase();
    match normalized.as_str() {
        "tq2" | "turboquant2" => KVCacheType::TurboQuantMSE { bits: 2 },
        "tq3" | "turboquant3" => KVCacheType::TurboQuantMSE { bits: 3 },
        "tq2-qjl" | "turboquant2-qjl" => KVCacheType::TurboQuantProd { bits: 2 },
        "tq3-qjl" | "turboquant3-qjl" => KVCacheType::TurboQuantProd { bits: 3 },
        _ => KVCacheType::F32,
    }
}