// Source: autoagents_llamacpp/config.rs
//! Configuration structures for llama.cpp provider.

use crate::models::ModelSource;
use llama_cpp_2::model::params::LlamaSplitMode;
use serde::{Deserialize, Serialize};

7/// Serializable split mode wrapper for llama.cpp.
/// Serializable split mode wrapper for llama.cpp.
///
/// Mirrors `LlamaSplitMode` from `llama_cpp_2` so multi-GPU split
/// configuration can be persisted through serde; convert to the native
/// type via the `From` impl in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LlamaCppSplitMode {
    /// Single device.
    None,
    /// Split layers and KV across GPUs.
    Layer,
    /// Split layers and KV across GPUs, use tensor parallelism if supported.
    Row,
}
17
18impl From<LlamaCppSplitMode> for LlamaSplitMode {
19    fn from(value: LlamaCppSplitMode) -> Self {
20        match value {
21            LlamaCppSplitMode::None => LlamaSplitMode::None,
22            LlamaCppSplitMode::Layer => LlamaSplitMode::Layer,
23            LlamaCppSplitMode::Row => LlamaSplitMode::Row,
24        }
25    }
26}
27
28/// Reasoning extraction format for llama.cpp OpenAI-compatible parsing.
/// Reasoning extraction format for llama.cpp OpenAI-compatible parsing.
///
/// Serialized in `snake_case` (e.g. `deepseek_legacy`); `as_str` yields the
/// corresponding string passed to llama.cpp, with `None` meaning "disabled".
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum LlamaCppReasoningFormat {
    /// Disable reasoning extraction into `reasoning_content`.
    None,
    /// Let llama.cpp auto-detect the model/template strategy.
    Auto,
    /// Parse DeepSeek/Qwen-style thinking into `reasoning_content`.
    Deepseek,
    /// Legacy DeepSeek behavior.
    DeepseekLegacy,
}
41
42impl LlamaCppReasoningFormat {
43    /// Convert to llama.cpp reasoning format string.
44    pub fn as_str(self) -> Option<&'static str> {
45        match self {
46            Self::None => None,
47            Self::Auto => Some("auto"),
48            Self::Deepseek => Some("deepseek"),
49            Self::DeepseekLegacy => Some("deepseek_legacy"),
50        }
51    }
52}
53
54/// Complete configuration for LlamaCppProvider.
/// Complete configuration for LlamaCppProvider.
///
/// Construct via `LlamaCppConfigBuilder` or `Default::default()`. Fields
/// left as `None` defer to the backend's own defaults.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlamaCppConfig {
    /// Model source (GGUF path).
    pub model_source: ModelSource,

    /// Optional chat template name or inline template.
    pub chat_template: Option<String>,

    /// Optional system prompt to prepend if no system message exists.
    pub system_prompt: Option<String>,

    /// Force JSON grammar enforcement even without a structured output schema.
    pub force_json_grammar: bool,

    /// Reasoning extraction mode for structured `reasoning_content`.
    /// `None` leaves extraction disabled (opt-in).
    pub reasoning_format: Option<LlamaCppReasoningFormat>,

    /// Optional `chat_template_kwargs` object passed to llama.cpp's OpenAI template API.
    ///
    /// Expected shape:
    /// `{ "chat_template_kwargs": { ... } }`
    pub extra_body: Option<serde_json::Value>,

    /// Optional HuggingFace cache directory (defaults to HF_HOME or ~/.cache/huggingface/hub).
    pub model_dir: Option<String>,

    /// Optional HuggingFace filename override (GGUF file).
    pub hf_filename: Option<String>,

    /// Optional HuggingFace revision (defaults to "main").
    pub hf_revision: Option<String>,

    /// Optional multimodal projection file for MTMD models.
    pub mmproj_path: Option<String>,

    /// Optional MTMD media marker override.
    pub media_marker: Option<String>,

    /// Enable GPU offload for MTMD projection.
    pub mmproj_use_gpu: Option<bool>,

    /// Maximum tokens to generate.
    pub max_tokens: Option<u32>,

    /// Sampling temperature (0.0 - 2.0).
    pub temperature: Option<f32>,

    /// Top-p sampling parameter.
    pub top_p: Option<f32>,

    /// Top-k sampling parameter.
    pub top_k: Option<u32>,

    /// Repeat penalty (1.0 disables).
    pub repeat_penalty: Option<f32>,

    /// Penalize frequency of tokens (0.0 disables).
    pub frequency_penalty: Option<f32>,

    /// Penalize presence of tokens (0.0 disables).
    pub presence_penalty: Option<f32>,

    /// Number of tokens to consider for penalties (None = default 64).
    pub repeat_last_n: Option<i32>,

    /// RNG seed for sampling (None = backend default — TODO confirm backend behavior).
    pub seed: Option<u32>,

    /// Context size override.
    pub n_ctx: Option<u32>,

    /// Batch size override.
    pub n_batch: Option<u32>,

    /// Micro-batch size override.
    pub n_ubatch: Option<u32>,

    /// Number of threads for prompt evaluation.
    pub n_threads: Option<i32>,

    /// Number of threads for batch evaluation.
    pub n_threads_batch: Option<i32>,

    /// Number of GPU layers to offload.
    pub n_gpu_layers: Option<u32>,

    /// Main GPU index.
    pub main_gpu: Option<i32>,

    /// Split mode for multi-GPU.
    pub split_mode: Option<LlamaCppSplitMode>,

    /// Enable memory lock (mlock) if supported.
    pub use_mlock: Option<bool>,

    /// Explicit device indices for offload.
    pub devices: Option<Vec<usize>>,
}
153
154impl Default for LlamaCppConfig {
155    fn default() -> Self {
156        Self {
157            model_source: ModelSource::Gguf {
158                model_path: String::default(),
159            },
160            chat_template: None,
161            system_prompt: None,
162            force_json_grammar: false,
163            reasoning_format: None,
164            extra_body: None,
165            model_dir: None,
166            hf_filename: None,
167            hf_revision: None,
168            mmproj_path: None,
169            media_marker: None,
170            mmproj_use_gpu: None,
171            max_tokens: Some(512),
172            temperature: Some(0.7),
173            top_p: None,
174            top_k: None,
175            repeat_penalty: None,
176            frequency_penalty: None,
177            presence_penalty: None,
178            repeat_last_n: None,
179            seed: None,
180            n_ctx: None,
181            n_batch: None,
182            n_ubatch: None,
183            n_threads: None,
184            n_threads_batch: None,
185            n_gpu_layers: None,
186            main_gpu: None,
187            split_mode: None,
188            use_mlock: None,
189            devices: None,
190        }
191    }
192}
193
194/// Builder for LlamaCppConfig.
/// Builder for LlamaCppConfig.
///
/// Starts from `LlamaCppConfig::default()` (via the derived `Default`) and
/// applies chained setters; finish with `build()`.
#[derive(Debug, Default)]
pub struct LlamaCppConfigBuilder {
    // Accumulated configuration, returned unchanged by `build()`.
    config: LlamaCppConfig,
}
199
200impl LlamaCppConfigBuilder {
201    /// Create a new builder with default configuration.
202    pub fn new() -> Self {
203        Self::default()
204    }
205
206    /// Set the model source.
207    pub fn model_source(mut self, source: ModelSource) -> Self {
208        self.config.model_source = source;
209        self
210    }
211
212    /// Set the model path for a local GGUF model.
213    pub fn model_path(mut self, path: impl Into<String>) -> Self {
214        self.config.model_source = ModelSource::gguf(path);
215        self
216    }
217
218    /// Set chat template.
219    pub fn chat_template(mut self, template: impl Into<String>) -> Self {
220        self.config.chat_template = Some(template.into());
221        self
222    }
223
224    /// Set system prompt.
225    pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
226        self.config.system_prompt = Some(prompt.into());
227        self
228    }
229
230    /// Force JSON grammar enforcement even without a structured output schema.
231    pub fn force_json_grammar(mut self, force: bool) -> Self {
232        self.config.force_json_grammar = force;
233        self
234    }
235
236    /// Set reasoning extraction format.
237    pub fn reasoning_format(mut self, format: LlamaCppReasoningFormat) -> Self {
238        self.config.reasoning_format = Some(format);
239        self
240    }
241
242    /// Set optional `chat_template_kwargs` payload for llama.cpp OpenAI template rendering.
243    pub fn extra_body(mut self, extra_body: impl Serialize) -> Self {
244        self.config.extra_body = serde_json::to_value(extra_body).ok();
245        self
246    }
247
248    /// Set the HuggingFace cache directory.
249    pub fn model_dir(mut self, dir: impl Into<String>) -> Self {
250        self.config.model_dir = Some(dir.into());
251        self
252    }
253
254    /// Set the HuggingFace filename (GGUF file).
255    pub fn hf_filename(mut self, filename: impl Into<String>) -> Self {
256        self.config.hf_filename = Some(filename.into());
257        self
258    }
259
260    /// Set the HuggingFace revision.
261    pub fn hf_revision(mut self, revision: impl Into<String>) -> Self {
262        self.config.hf_revision = Some(revision.into());
263        self
264    }
265
266    /// Set the multimodal projection (mmproj) file path.
267    pub fn mmproj_path(mut self, path: impl Into<String>) -> Self {
268        self.config.mmproj_path = Some(path.into());
269        self
270    }
271
272    /// Set MTMD media marker.
273    pub fn media_marker(mut self, marker: impl Into<String>) -> Self {
274        self.config.media_marker = Some(marker.into());
275        self
276    }
277
278    /// Enable or disable GPU offload for MTMD projection.
279    pub fn mmproj_use_gpu(mut self, use_gpu: bool) -> Self {
280        self.config.mmproj_use_gpu = Some(use_gpu);
281        self
282    }
283
284    /// Set maximum tokens to generate.
285    pub fn max_tokens(mut self, tokens: u32) -> Self {
286        self.config.max_tokens = Some(tokens);
287        self
288    }
289
290    /// Set sampling temperature.
291    pub fn temperature(mut self, temp: f32) -> Self {
292        self.config.temperature = Some(temp);
293        self
294    }
295
296    /// Set top-p sampling parameter.
297    pub fn top_p(mut self, p: f32) -> Self {
298        self.config.top_p = Some(p);
299        self
300    }
301
302    /// Set top-k sampling parameter.
303    pub fn top_k(mut self, k: u32) -> Self {
304        self.config.top_k = Some(k);
305        self
306    }
307
308    /// Set repeat penalty.
309    pub fn repeat_penalty(mut self, penalty: f32) -> Self {
310        self.config.repeat_penalty = Some(penalty);
311        self
312    }
313
314    /// Set frequency penalty.
315    pub fn frequency_penalty(mut self, penalty: f32) -> Self {
316        self.config.frequency_penalty = Some(penalty);
317        self
318    }
319
320    /// Set presence penalty.
321    pub fn presence_penalty(mut self, penalty: f32) -> Self {
322        self.config.presence_penalty = Some(penalty);
323        self
324    }
325
326    /// Set repeat last N for penalties.
327    pub fn repeat_last_n(mut self, last_n: i32) -> Self {
328        self.config.repeat_last_n = Some(last_n);
329        self
330    }
331
332    /// Set sampling seed.
333    pub fn seed(mut self, seed: u32) -> Self {
334        self.config.seed = Some(seed);
335        self
336    }
337
338    /// Set context size.
339    pub fn n_ctx(mut self, n_ctx: u32) -> Self {
340        self.config.n_ctx = Some(n_ctx);
341        self
342    }
343
344    /// Set batch size.
345    pub fn n_batch(mut self, n_batch: u32) -> Self {
346        self.config.n_batch = Some(n_batch);
347        self
348    }
349
350    /// Set micro-batch size.
351    pub fn n_ubatch(mut self, n_ubatch: u32) -> Self {
352        self.config.n_ubatch = Some(n_ubatch);
353        self
354    }
355
356    /// Set number of threads for prompt evaluation.
357    pub fn n_threads(mut self, n_threads: i32) -> Self {
358        self.config.n_threads = Some(n_threads);
359        self
360    }
361
362    /// Set number of threads for batch evaluation.
363    pub fn n_threads_batch(mut self, n_threads: i32) -> Self {
364        self.config.n_threads_batch = Some(n_threads);
365        self
366    }
367
368    /// Set number of GPU layers to offload.
369    pub fn n_gpu_layers(mut self, layers: u32) -> Self {
370        self.config.n_gpu_layers = Some(layers);
371        self
372    }
373
374    /// Set main GPU index.
375    pub fn main_gpu(mut self, main_gpu: i32) -> Self {
376        self.config.main_gpu = Some(main_gpu);
377        self
378    }
379
380    /// Set split mode.
381    pub fn split_mode(mut self, mode: LlamaCppSplitMode) -> Self {
382        self.config.split_mode = Some(mode);
383        self
384    }
385
386    /// Enable memory lock.
387    pub fn use_mlock(mut self, use_mlock: bool) -> Self {
388        self.config.use_mlock = Some(use_mlock);
389        self
390    }
391
392    /// Set explicit device indices for offload.
393    pub fn devices(mut self, devices: Vec<usize>) -> Self {
394        self.config.devices = Some(devices);
395        self
396    }
397
398    /// Build the configuration.
399    pub fn build(self) -> LlamaCppConfig {
400        self.config
401    }
402}
403
#[cfg(test)]
mod tests {
    use super::*;

    /// Core builder path: model path plus the two sampling knobs that have
    /// non-`None` defaults.
    #[test]
    fn test_config_builder_basic() {
        let cfg = LlamaCppConfigBuilder::new()
            .model_path("model.gguf")
            .max_tokens(1024)
            .temperature(0.8)
            .build();

        let expected_source = ModelSource::Gguf {
            model_path: "model.gguf".to_string(),
        };
        assert_eq!(cfg.model_source, expected_source);
        assert_eq!(cfg.max_tokens, Some(1024));
        assert_eq!(cfg.temperature, Some(0.8));
    }

    /// Optional flags and structured extra_body round-trip through the builder.
    #[test]
    fn test_config_builder_optional_flags() {
        let body = serde_json::json!({
            "chat_template_kwargs": {
                "enable_thinking": true
            }
        });
        let cfg = LlamaCppConfigBuilder::new()
            .model_path("model.gguf")
            .force_json_grammar(true)
            .reasoning_format(LlamaCppReasoningFormat::Deepseek)
            .extra_body(body)
            .mmproj_use_gpu(true)
            .split_mode(LlamaCppSplitMode::Layer)
            .use_mlock(true)
            .devices(vec![0, 1])
            .build();

        assert!(cfg.force_json_grammar);
        assert_eq!(cfg.reasoning_format, Some(LlamaCppReasoningFormat::Deepseek));

        let enable_thinking = cfg
            .extra_body
            .as_ref()
            .and_then(|v| v.get("chat_template_kwargs"))
            .and_then(|v| v.get("enable_thinking"))
            .and_then(|v| v.as_bool());
        assert_eq!(enable_thinking, Some(true));

        assert_eq!(cfg.mmproj_use_gpu, Some(true));
        assert_eq!(cfg.split_mode, Some(LlamaCppSplitMode::Layer));
        assert_eq!(cfg.use_mlock, Some(true));
        assert_eq!(cfg.devices, Some(vec![0, 1]));
    }

    /// Reasoning extraction must be opt-in: default config leaves it unset.
    #[test]
    fn test_config_default_reasoning_format_is_opt_in() {
        assert_eq!(LlamaCppConfig::default().reasoning_format, None);
    }

    /// Exercise the remaining setters and spot-check the stored values.
    #[test]
    fn test_config_builder_selected_options() {
        let cfg = LlamaCppConfigBuilder::new()
            .model_source(ModelSource::huggingface_with_filename(
                "org/model",
                "model.gguf",
            ))
            .chat_template("chat-template")
            .system_prompt("system")
            .model_dir("cache")
            .hf_filename("override.gguf")
            .hf_revision("rev1")
            .mmproj_path("mmproj.gguf")
            .media_marker("[IMG]")
            .max_tokens(123)
            .temperature(0.5)
            .top_p(0.9)
            .top_k(42)
            .repeat_penalty(1.1)
            .frequency_penalty(0.2)
            .presence_penalty(0.3)
            .repeat_last_n(32)
            .seed(7)
            .n_ctx(2048)
            .n_batch(64)
            .n_ubatch(8)
            .n_threads(4)
            .n_threads_batch(2)
            .n_gpu_layers(3)
            .main_gpu(1)
            .build();

        assert!(matches!(cfg.model_source, ModelSource::HuggingFace { .. }));
        assert_eq!(cfg.chat_template.as_deref(), Some("chat-template"));
        assert_eq!(cfg.system_prompt.as_deref(), Some("system"));
        assert_eq!(cfg.model_dir.as_deref(), Some("cache"));
        assert_eq!(cfg.hf_filename.as_deref(), Some("override.gguf"));
        assert_eq!(cfg.hf_revision.as_deref(), Some("rev1"));
        assert_eq!(cfg.mmproj_path.as_deref(), Some("mmproj.gguf"));
        assert_eq!(cfg.media_marker.as_deref(), Some("[IMG]"));
        assert_eq!(cfg.max_tokens, Some(123));
        assert_eq!(cfg.temperature, Some(0.5));
        assert_eq!(cfg.n_ctx, Some(2048));
        assert_eq!(cfg.n_threads, Some(4));
        assert_eq!(cfg.n_gpu_layers, Some(3));
        assert_eq!(cfg.main_gpu, Some(1));
    }
}