use crate::models::ModelSource;
use llama_cpp_2::model::params::LlamaSplitMode;
use serde::{Deserialize, Serialize};
/// Serializable mirror of [`LlamaSplitMode`]: how the model is distributed
/// across multiple GPUs.
///
/// NOTE(review): unlike `LlamaCppReasoningFormat`, this enum has no
/// `#[serde(rename_all = "snake_case")]`, so variants serialize as
/// "None"/"Layer"/"Row". Confirm the casing mismatch is intentional —
/// aligning it now would break previously serialized configs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LlamaCppSplitMode {
    /// Maps to [`LlamaSplitMode::None`].
    None,
    /// Maps to [`LlamaSplitMode::Layer`].
    Layer,
    /// Maps to [`LlamaSplitMode::Row`].
    Row,
}
/// Bridges the serializable config enum to the `llama_cpp_2` split mode.
impl From<LlamaCppSplitMode> for LlamaSplitMode {
    fn from(mode: LlamaCppSplitMode) -> Self {
        // One-to-one variant mapping; exhaustive so new variants are a
        // compile error here.
        match mode {
            LlamaCppSplitMode::None => Self::None,
            LlamaCppSplitMode::Layer => Self::Layer,
            LlamaCppSplitMode::Row => Self::Row,
        }
    }
}
/// Reasoning ("thinking") output format selector.
///
/// Serialized in snake_case (e.g. `deepseek_legacy`); [`Self::as_str`]
/// yields the matching string flag, with [`Self::None`] producing no flag.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum LlamaCppReasoningFormat {
    /// No reasoning-format flag (`as_str` returns `None`).
    None,
    /// Maps to the string `"auto"`.
    Auto,
    /// Maps to the string `"deepseek"`.
    Deepseek,
    /// Maps to the string `"deepseek_legacy"`.
    DeepseekLegacy,
}
impl LlamaCppReasoningFormat {
    /// Returns the string identifier for this format, or `None` for
    /// [`LlamaCppReasoningFormat::None`] (meaning: emit no flag at all).
    pub fn as_str(self) -> Option<&'static str> {
        // Early-return the flag-less case; every other variant has a
        // fixed static name.
        let name = match self {
            Self::None => return None,
            Self::Auto => "auto",
            Self::Deepseek => "deepseek",
            Self::DeepseekLegacy => "deepseek_legacy",
        };
        Some(name)
    }
}
/// Configuration for a llama.cpp model run.
///
/// Every `Option` field falls back to the backend default when `None`;
/// [`LlamaCppConfig::default`] only pins `max_tokens` and `temperature`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlamaCppConfig {
    /// Where the model comes from (local GGUF file or Hugging Face repo).
    pub model_source: ModelSource,
    /// Chat template override.
    pub chat_template: Option<String>,
    /// System prompt text.
    pub system_prompt: Option<String>,
    /// When `true`, constrain generation with a JSON grammar.
    pub force_json_grammar: bool,
    /// Reasoning ("thinking") output format; `None` keeps backend behavior.
    pub reasoning_format: Option<LlamaCppReasoningFormat>,
    /// Arbitrary extra JSON (e.g. `chat_template_kwargs` — see tests below).
    pub extra_body: Option<serde_json::Value>,
    /// Directory for model files — NOTE(review): whether this is a download
    /// cache or a search path isn't visible from this file; confirm at use site.
    pub model_dir: Option<String>,
    /// Filename override for Hugging Face downloads.
    pub hf_filename: Option<String>,
    /// Revision (branch/tag/commit) for Hugging Face downloads.
    pub hf_revision: Option<String>,
    /// Path to a multimodal projector (mmproj) file.
    pub mmproj_path: Option<String>,
    /// Marker string standing in for media in prompts (e.g. `"[IMG]"`).
    pub media_marker: Option<String>,
    /// Run the multimodal projector on GPU.
    pub mmproj_use_gpu: Option<bool>,
    /// Maximum number of tokens to generate.
    pub max_tokens: Option<u32>,
    /// Sampling temperature.
    pub temperature: Option<f32>,
    /// Nucleus (top-p) sampling threshold.
    pub top_p: Option<f32>,
    /// Top-k sampling cutoff.
    pub top_k: Option<u32>,
    /// Repetition penalty factor.
    pub repeat_penalty: Option<f32>,
    /// Frequency penalty.
    pub frequency_penalty: Option<f32>,
    /// Presence penalty.
    pub presence_penalty: Option<f32>,
    /// Number of recent tokens considered by the repeat penalty.
    pub repeat_last_n: Option<i32>,
    /// RNG seed for reproducible sampling.
    pub seed: Option<u32>,
    /// Context window size, in tokens.
    pub n_ctx: Option<u32>,
    /// Logical batch size.
    pub n_batch: Option<u32>,
    /// Physical micro-batch size.
    pub n_ubatch: Option<u32>,
    /// Threads used for generation.
    pub n_threads: Option<i32>,
    /// Threads used for batch/prompt processing.
    pub n_threads_batch: Option<i32>,
    /// Number of layers offloaded to GPU.
    pub n_gpu_layers: Option<u32>,
    /// Index of the primary GPU.
    pub main_gpu: Option<i32>,
    /// Multi-GPU split strategy.
    pub split_mode: Option<LlamaCppSplitMode>,
    /// Lock model memory with `mlock` to avoid swapping.
    pub use_mlock: Option<bool>,
    /// Device indices the backend may use.
    pub devices: Option<Vec<usize>>,
}
impl Default for LlamaCppConfig {
    /// Baseline config: placeholder (empty) GGUF path, a 512-token
    /// generation cap and temperature 0.7; every other knob is left to the
    /// backend by staying `None`.
    fn default() -> Self {
        Self {
            // Source / templating: an empty local path acts as a placeholder
            // until the caller sets a real source.
            model_source: ModelSource::Gguf {
                model_path: String::new(),
            },
            chat_template: None,
            system_prompt: None,
            force_json_grammar: false,
            reasoning_format: None,
            extra_body: None,
            // Download and multimodal settings are all opt-in.
            model_dir: None,
            hf_filename: None,
            hf_revision: None,
            mmproj_path: None,
            media_marker: None,
            mmproj_use_gpu: None,
            // Sampling: only the generation cap and temperature are pinned.
            max_tokens: Some(512),
            temperature: Some(0.7),
            top_p: None,
            top_k: None,
            repeat_penalty: None,
            frequency_penalty: None,
            presence_penalty: None,
            repeat_last_n: None,
            seed: None,
            // Context / threading / GPU knobs: backend defaults.
            n_ctx: None,
            n_batch: None,
            n_ubatch: None,
            n_threads: None,
            n_threads_batch: None,
            n_gpu_layers: None,
            main_gpu: None,
            split_mode: None,
            use_mlock: None,
            devices: None,
        }
    }
}
/// Fluent builder for [`LlamaCppConfig`].
///
/// Starts from [`LlamaCppConfig::default`] (via `Default`) and overrides
/// fields one call at a time; finish with [`LlamaCppConfigBuilder::build`].
#[derive(Debug, Default)]
pub struct LlamaCppConfigBuilder {
    // Accumulated configuration, returned as-is by `build`.
    config: LlamaCppConfig,
}
impl LlamaCppConfigBuilder {
    /// Creates a builder seeded with [`LlamaCppConfig::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Replaces the model source wholesale.
    pub fn model_source(mut self, src: ModelSource) -> Self {
        self.config.model_source = src;
        self
    }

    /// Points the model source at a local GGUF file.
    pub fn model_path(mut self, gguf: impl Into<String>) -> Self {
        self.config.model_source = ModelSource::gguf(gguf);
        self
    }

    /// Sets a custom chat template.
    pub fn chat_template(mut self, tpl: impl Into<String>) -> Self {
        self.config.chat_template = Some(tpl.into());
        self
    }

    /// Sets the system prompt.
    pub fn system_prompt(mut self, text: impl Into<String>) -> Self {
        self.config.system_prompt = Some(text.into());
        self
    }

    /// Toggles forced JSON-grammar output.
    pub fn force_json_grammar(mut self, enabled: bool) -> Self {
        self.config.force_json_grammar = enabled;
        self
    }

    /// Selects the reasoning output format.
    pub fn reasoning_format(mut self, fmt: LlamaCppReasoningFormat) -> Self {
        self.config.reasoning_format = Some(fmt);
        self
    }

    /// Attaches extra JSON built from any serializable value.
    ///
    /// NOTE(review): serialization failures are silently dropped, leaving
    /// `extra_body` as `None` — best-effort by design, apparently.
    pub fn extra_body(mut self, payload: impl Serialize) -> Self {
        self.config.extra_body = serde_json::to_value(payload).ok();
        self
    }

    /// Sets the directory used for model files.
    pub fn model_dir(mut self, directory: impl Into<String>) -> Self {
        self.config.model_dir = Some(directory.into());
        self
    }

    /// Sets the Hugging Face filename override.
    pub fn hf_filename(mut self, name: impl Into<String>) -> Self {
        self.config.hf_filename = Some(name.into());
        self
    }

    /// Sets the Hugging Face revision.
    pub fn hf_revision(mut self, rev: impl Into<String>) -> Self {
        self.config.hf_revision = Some(rev.into());
        self
    }

    /// Sets the multimodal projector (mmproj) file path.
    pub fn mmproj_path(mut self, file: impl Into<String>) -> Self {
        self.config.mmproj_path = Some(file.into());
        self
    }

    /// Sets the media placeholder marker.
    pub fn media_marker(mut self, token: impl Into<String>) -> Self {
        self.config.media_marker = Some(token.into());
        self
    }

    /// Chooses whether the multimodal projector runs on GPU.
    pub fn mmproj_use_gpu(mut self, on_gpu: bool) -> Self {
        self.config.mmproj_use_gpu = Some(on_gpu);
        self
    }

    /// Caps the number of generated tokens.
    pub fn max_tokens(mut self, limit: u32) -> Self {
        self.config.max_tokens = Some(limit);
        self
    }

    /// Sets the sampling temperature.
    pub fn temperature(mut self, value: f32) -> Self {
        self.config.temperature = Some(value);
        self
    }

    /// Sets the nucleus (top-p) sampling threshold.
    pub fn top_p(mut self, nucleus: f32) -> Self {
        self.config.top_p = Some(nucleus);
        self
    }

    /// Sets the top-k sampling cutoff.
    pub fn top_k(mut self, cutoff: u32) -> Self {
        self.config.top_k = Some(cutoff);
        self
    }

    /// Sets the repetition penalty factor.
    pub fn repeat_penalty(mut self, factor: f32) -> Self {
        self.config.repeat_penalty = Some(factor);
        self
    }

    /// Sets the frequency penalty.
    pub fn frequency_penalty(mut self, factor: f32) -> Self {
        self.config.frequency_penalty = Some(factor);
        self
    }

    /// Sets the presence penalty.
    pub fn presence_penalty(mut self, factor: f32) -> Self {
        self.config.presence_penalty = Some(factor);
        self
    }

    /// Sets how many recent tokens the repeat penalty considers.
    pub fn repeat_last_n(mut self, window: i32) -> Self {
        self.config.repeat_last_n = Some(window);
        self
    }

    /// Fixes the sampling seed.
    pub fn seed(mut self, value: u32) -> Self {
        self.config.seed = Some(value);
        self
    }

    /// Sets the context window size, in tokens.
    pub fn n_ctx(mut self, size: u32) -> Self {
        self.config.n_ctx = Some(size);
        self
    }

    /// Sets the logical batch size.
    pub fn n_batch(mut self, size: u32) -> Self {
        self.config.n_batch = Some(size);
        self
    }

    /// Sets the physical micro-batch size.
    pub fn n_ubatch(mut self, size: u32) -> Self {
        self.config.n_ubatch = Some(size);
        self
    }

    /// Sets the generation thread count.
    pub fn n_threads(mut self, count: i32) -> Self {
        self.config.n_threads = Some(count);
        self
    }

    /// Sets the batch-processing thread count.
    pub fn n_threads_batch(mut self, count: i32) -> Self {
        self.config.n_threads_batch = Some(count);
        self
    }

    /// Sets how many layers to offload to the GPU.
    pub fn n_gpu_layers(mut self, count: u32) -> Self {
        self.config.n_gpu_layers = Some(count);
        self
    }

    /// Selects the primary GPU index.
    pub fn main_gpu(mut self, index: i32) -> Self {
        self.config.main_gpu = Some(index);
        self
    }

    /// Chooses how the model is split across GPUs.
    pub fn split_mode(mut self, mode: LlamaCppSplitMode) -> Self {
        self.config.split_mode = Some(mode);
        self
    }

    /// Toggles `mlock` for model memory.
    pub fn use_mlock(mut self, lock: bool) -> Self {
        self.config.use_mlock = Some(lock);
        self
    }

    /// Restricts execution to the given device indices.
    pub fn devices(mut self, ids: Vec<usize>) -> Self {
        self.config.devices = Some(ids);
        self
    }

    /// Consumes the builder and returns the finished config.
    pub fn build(self) -> LlamaCppConfig {
        self.config
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The basic builder path wires the model path and the two common
    /// sampling knobs through to the config.
    #[test]
    fn test_config_builder_basic() {
        let cfg = LlamaCppConfigBuilder::new()
            .model_path("model.gguf")
            .max_tokens(1024)
            .temperature(0.8)
            .build();

        let expected_source = ModelSource::Gguf {
            model_path: "model.gguf".to_string(),
        };
        assert_eq!(cfg.model_source, expected_source);
        assert_eq!(cfg.max_tokens, Some(1024));
        assert_eq!(cfg.temperature, Some(0.8));
    }

    /// Opt-in flags all survive the round trip through the builder.
    #[test]
    fn test_config_builder_optional_flags() {
        let cfg = LlamaCppConfigBuilder::new()
            .model_path("model.gguf")
            .force_json_grammar(true)
            .reasoning_format(LlamaCppReasoningFormat::Deepseek)
            .extra_body(serde_json::json!({
                "chat_template_kwargs": {
                    "enable_thinking": true
                }
            }))
            .mmproj_use_gpu(true)
            .split_mode(LlamaCppSplitMode::Layer)
            .use_mlock(true)
            .devices(vec![0, 1])
            .build();

        assert!(cfg.force_json_grammar);
        assert_eq!(
            cfg.reasoning_format,
            Some(LlamaCppReasoningFormat::Deepseek)
        );
        // JSON-pointer lookup: equivalent to chained `get` calls.
        let thinking = cfg
            .extra_body
            .as_ref()
            .and_then(|body| body.pointer("/chat_template_kwargs/enable_thinking"))
            .and_then(|flag| flag.as_bool());
        assert_eq!(thinking, Some(true));
        assert_eq!(cfg.mmproj_use_gpu, Some(true));
        assert_eq!(cfg.split_mode, Some(LlamaCppSplitMode::Layer));
        assert_eq!(cfg.use_mlock, Some(true));
        assert_eq!(cfg.devices, Some(vec![0, 1]));
    }

    /// Reasoning format is opt-in: the default config leaves it unset.
    #[test]
    fn test_config_default_reasoning_format_is_opt_in() {
        assert_eq!(LlamaCppConfig::default().reasoning_format, None);
    }

    /// Broad sweep of setters, spot-checking a representative subset.
    #[test]
    fn test_config_builder_selected_options() {
        let cfg = LlamaCppConfigBuilder::new()
            .model_source(ModelSource::huggingface_with_filename(
                "org/model",
                "model.gguf",
            ))
            .chat_template("chat-template")
            .system_prompt("system")
            .model_dir("cache")
            .hf_filename("override.gguf")
            .hf_revision("rev1")
            .mmproj_path("mmproj.gguf")
            .media_marker("[IMG]")
            .max_tokens(123)
            .temperature(0.5)
            .top_p(0.9)
            .top_k(42)
            .repeat_penalty(1.1)
            .frequency_penalty(0.2)
            .presence_penalty(0.3)
            .repeat_last_n(32)
            .seed(7)
            .n_ctx(2048)
            .n_batch(64)
            .n_ubatch(8)
            .n_threads(4)
            .n_threads_batch(2)
            .n_gpu_layers(3)
            .main_gpu(1)
            .build();

        assert!(matches!(cfg.model_source, ModelSource::HuggingFace { .. }));
        assert_eq!(cfg.chat_template.as_deref(), Some("chat-template"));
        assert_eq!(cfg.system_prompt.as_deref(), Some("system"));
        assert_eq!(cfg.model_dir.as_deref(), Some("cache"));
        assert_eq!(cfg.hf_filename.as_deref(), Some("override.gguf"));
        assert_eq!(cfg.hf_revision.as_deref(), Some("rev1"));
        assert_eq!(cfg.mmproj_path.as_deref(), Some("mmproj.gguf"));
        assert_eq!(cfg.media_marker.as_deref(), Some("[IMG]"));
        assert_eq!(cfg.max_tokens, Some(123));
        assert_eq!(cfg.temperature, Some(0.5));
        assert_eq!(cfg.n_ctx, Some(2048));
        assert_eq!(cfg.n_threads, Some(4));
        assert_eq!(cfg.n_gpu_layers, Some(3));
        assert_eq!(cfg.main_gpu, Some(1));
    }
}