1use crate::models::ModelSource;
4use llama_cpp_2::model::params::LlamaSplitMode;
5use serde::{Deserialize, Serialize};
6
/// How model weights are distributed across multiple GPUs.
///
/// Serializable mirror of [`LlamaSplitMode`] from `llama_cpp_2`, so the
/// choice can be persisted in configuration files.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
// NOTE(review): unlike `LlamaCppReasoningFormat` below, this enum has no
// `#[serde(rename_all = "snake_case")]`, so variants serialize as
// "None"/"Layer"/"Row" — confirm this asymmetry is intentional (changing it
// now would break existing serialized configs).
pub enum LlamaCppSplitMode {
    /// No splitting; the whole model lives on a single GPU.
    None,
    /// Split whole layers across the available GPUs.
    Layer,
    /// Split individual tensor rows across GPUs.
    Row,
}
17
/// Converts the serializable split mode into the native `llama_cpp_2` type.
/// Variants map 1:1, so the conversion is infallible.
impl From<LlamaCppSplitMode> for LlamaSplitMode {
    fn from(value: LlamaCppSplitMode) -> Self {
        match value {
            LlamaCppSplitMode::None => LlamaSplitMode::None,
            LlamaCppSplitMode::Layer => LlamaSplitMode::Layer,
            LlamaCppSplitMode::Row => LlamaSplitMode::Row,
        }
    }
}
27
/// Reasoning-output format requested from the llama.cpp backend.
///
/// Serialized in snake_case (e.g. `deepseek_legacy`); see `as_str` for the
/// string identifiers handed to the backend.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum LlamaCppReasoningFormat {
    /// Explicitly request no reasoning format (`as_str` returns `None`).
    None,
    /// Let the backend choose the format automatically.
    Auto,
    /// DeepSeek-style reasoning format.
    Deepseek,
    /// Legacy DeepSeek reasoning format.
    DeepseekLegacy,
}
41
42impl LlamaCppReasoningFormat {
43 pub fn as_str(self) -> Option<&'static str> {
45 match self {
46 Self::None => None,
47 Self::Auto => Some("auto"),
48 Self::Deepseek => Some("deepseek"),
49 Self::DeepseekLegacy => Some("deepseek_legacy"),
50 }
51 }
52}
53
/// Configuration for running a model through the llama.cpp backend.
///
/// Every tunable is an `Option`; `None` defers to the backend's own default.
/// Field names mirror the corresponding llama.cpp parameters
/// (`n_ctx`, `n_batch`, `n_gpu_layers`, ...).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlamaCppConfig {
    /// Where the model comes from (local GGUF path or Hugging Face repo).
    pub model_source: ModelSource,

    /// Chat template override; `None` uses the template shipped with the model.
    pub chat_template: Option<String>,

    /// System prompt to prepend to conversations.
    pub system_prompt: Option<String>,

    /// Force a JSON grammar on generated output.
    pub force_json_grammar: bool,

    /// Reasoning format to request; opt-in (defaults to `None`).
    pub reasoning_format: Option<LlamaCppReasoningFormat>,

    /// Extra JSON merged into requests (e.g. `chat_template_kwargs`).
    pub extra_body: Option<serde_json::Value>,

    /// Directory for locating/caching model files.
    pub model_dir: Option<String>,

    /// Filename override inside a Hugging Face repo.
    pub hf_filename: Option<String>,

    /// Hugging Face revision (branch, tag, or commit).
    pub hf_revision: Option<String>,

    /// Path to a multimodal projector (mmproj) GGUF file.
    pub mmproj_path: Option<String>,

    /// Placeholder marker standing in for media in prompts (e.g. "[IMG]").
    pub media_marker: Option<String>,

    /// Whether the multimodal projector should run on the GPU.
    pub mmproj_use_gpu: Option<bool>,

    /// Maximum number of tokens to generate per completion.
    pub max_tokens: Option<u32>,

    /// Sampling temperature.
    pub temperature: Option<f32>,

    /// Nucleus (top-p) sampling probability mass.
    pub top_p: Option<f32>,

    /// Top-k sampling cutoff.
    pub top_k: Option<u32>,

    /// Repetition penalty multiplier.
    pub repeat_penalty: Option<f32>,

    /// Frequency penalty.
    pub frequency_penalty: Option<f32>,

    /// Presence penalty.
    pub presence_penalty: Option<f32>,

    /// How many trailing tokens the repetition penalties consider.
    pub repeat_last_n: Option<i32>,

    /// RNG seed for reproducible sampling.
    pub seed: Option<u32>,

    /// Context window size (llama.cpp `n_ctx`).
    pub n_ctx: Option<u32>,

    /// Logical batch size (llama.cpp `n_batch`).
    pub n_batch: Option<u32>,

    /// Physical micro-batch size (llama.cpp `n_ubatch`).
    pub n_ubatch: Option<u32>,

    /// Thread count for generation.
    pub n_threads: Option<i32>,

    /// Thread count for batch/prompt processing.
    pub n_threads_batch: Option<i32>,

    /// Number of layers to offload to the GPU.
    pub n_gpu_layers: Option<u32>,

    /// Index of the primary GPU.
    pub main_gpu: Option<i32>,

    /// How to split the model across multiple GPUs.
    pub split_mode: Option<LlamaCppSplitMode>,

    /// Lock model memory with mlock to prevent swapping.
    pub use_mlock: Option<bool>,

    // NOTE(review): presumably backend device indices — confirm how these map
    // to llama.cpp device selection in the loader.
    pub devices: Option<Vec<usize>>,
}
153
impl Default for LlamaCppConfig {
    /// Defaults: an empty GGUF path (callers must set a real model source
    /// before use), `max_tokens = 512`, `temperature = 0.7`; everything else
    /// is `None`, deferring to the backend's defaults.
    fn default() -> Self {
        Self {
            // Placeholder source — an empty path is not loadable as-is.
            model_source: ModelSource::Gguf {
                model_path: String::default(),
            },
            chat_template: None,
            system_prompt: None,
            force_json_grammar: false,
            reasoning_format: None,
            extra_body: None,
            model_dir: None,
            hf_filename: None,
            hf_revision: None,
            mmproj_path: None,
            media_marker: None,
            mmproj_use_gpu: None,
            // The only sampling parameters with opinionated defaults.
            max_tokens: Some(512),
            temperature: Some(0.7),
            top_p: None,
            top_k: None,
            repeat_penalty: None,
            frequency_penalty: None,
            presence_penalty: None,
            repeat_last_n: None,
            seed: None,
            n_ctx: None,
            n_batch: None,
            n_ubatch: None,
            n_threads: None,
            n_threads_batch: None,
            n_gpu_layers: None,
            main_gpu: None,
            split_mode: None,
            use_mlock: None,
            devices: None,
        }
    }
}
193
/// Fluent builder for [`LlamaCppConfig`].
///
/// `Default` starts from [`LlamaCppConfig::default`], so unset fields keep
/// the config defaults.
#[derive(Debug, Default)]
pub struct LlamaCppConfigBuilder {
    // Configuration being accumulated; returned as-is by `build`.
    config: LlamaCppConfig,
}
199
200impl LlamaCppConfigBuilder {
201 pub fn new() -> Self {
203 Self::default()
204 }
205
206 pub fn model_source(mut self, source: ModelSource) -> Self {
208 self.config.model_source = source;
209 self
210 }
211
212 pub fn model_path(mut self, path: impl Into<String>) -> Self {
214 self.config.model_source = ModelSource::gguf(path);
215 self
216 }
217
218 pub fn chat_template(mut self, template: impl Into<String>) -> Self {
220 self.config.chat_template = Some(template.into());
221 self
222 }
223
224 pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
226 self.config.system_prompt = Some(prompt.into());
227 self
228 }
229
230 pub fn force_json_grammar(mut self, force: bool) -> Self {
232 self.config.force_json_grammar = force;
233 self
234 }
235
236 pub fn reasoning_format(mut self, format: LlamaCppReasoningFormat) -> Self {
238 self.config.reasoning_format = Some(format);
239 self
240 }
241
242 pub fn extra_body(mut self, extra_body: impl Serialize) -> Self {
244 self.config.extra_body = serde_json::to_value(extra_body).ok();
245 self
246 }
247
248 pub fn model_dir(mut self, dir: impl Into<String>) -> Self {
250 self.config.model_dir = Some(dir.into());
251 self
252 }
253
254 pub fn hf_filename(mut self, filename: impl Into<String>) -> Self {
256 self.config.hf_filename = Some(filename.into());
257 self
258 }
259
260 pub fn hf_revision(mut self, revision: impl Into<String>) -> Self {
262 self.config.hf_revision = Some(revision.into());
263 self
264 }
265
266 pub fn mmproj_path(mut self, path: impl Into<String>) -> Self {
268 self.config.mmproj_path = Some(path.into());
269 self
270 }
271
272 pub fn media_marker(mut self, marker: impl Into<String>) -> Self {
274 self.config.media_marker = Some(marker.into());
275 self
276 }
277
278 pub fn mmproj_use_gpu(mut self, use_gpu: bool) -> Self {
280 self.config.mmproj_use_gpu = Some(use_gpu);
281 self
282 }
283
284 pub fn max_tokens(mut self, tokens: u32) -> Self {
286 self.config.max_tokens = Some(tokens);
287 self
288 }
289
290 pub fn temperature(mut self, temp: f32) -> Self {
292 self.config.temperature = Some(temp);
293 self
294 }
295
296 pub fn top_p(mut self, p: f32) -> Self {
298 self.config.top_p = Some(p);
299 self
300 }
301
302 pub fn top_k(mut self, k: u32) -> Self {
304 self.config.top_k = Some(k);
305 self
306 }
307
308 pub fn repeat_penalty(mut self, penalty: f32) -> Self {
310 self.config.repeat_penalty = Some(penalty);
311 self
312 }
313
314 pub fn frequency_penalty(mut self, penalty: f32) -> Self {
316 self.config.frequency_penalty = Some(penalty);
317 self
318 }
319
320 pub fn presence_penalty(mut self, penalty: f32) -> Self {
322 self.config.presence_penalty = Some(penalty);
323 self
324 }
325
326 pub fn repeat_last_n(mut self, last_n: i32) -> Self {
328 self.config.repeat_last_n = Some(last_n);
329 self
330 }
331
332 pub fn seed(mut self, seed: u32) -> Self {
334 self.config.seed = Some(seed);
335 self
336 }
337
338 pub fn n_ctx(mut self, n_ctx: u32) -> Self {
340 self.config.n_ctx = Some(n_ctx);
341 self
342 }
343
344 pub fn n_batch(mut self, n_batch: u32) -> Self {
346 self.config.n_batch = Some(n_batch);
347 self
348 }
349
350 pub fn n_ubatch(mut self, n_ubatch: u32) -> Self {
352 self.config.n_ubatch = Some(n_ubatch);
353 self
354 }
355
356 pub fn n_threads(mut self, n_threads: i32) -> Self {
358 self.config.n_threads = Some(n_threads);
359 self
360 }
361
362 pub fn n_threads_batch(mut self, n_threads: i32) -> Self {
364 self.config.n_threads_batch = Some(n_threads);
365 self
366 }
367
368 pub fn n_gpu_layers(mut self, layers: u32) -> Self {
370 self.config.n_gpu_layers = Some(layers);
371 self
372 }
373
374 pub fn main_gpu(mut self, main_gpu: i32) -> Self {
376 self.config.main_gpu = Some(main_gpu);
377 self
378 }
379
380 pub fn split_mode(mut self, mode: LlamaCppSplitMode) -> Self {
382 self.config.split_mode = Some(mode);
383 self
384 }
385
386 pub fn use_mlock(mut self, use_mlock: bool) -> Self {
388 self.config.use_mlock = Some(use_mlock);
389 self
390 }
391
392 pub fn devices(mut self, devices: Vec<usize>) -> Self {
394 self.config.devices = Some(devices);
395 self
396 }
397
398 pub fn build(self) -> LlamaCppConfig {
400 self.config
401 }
402}
403
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal builder chain: model path plus two sampling overrides.
    #[test]
    fn test_config_builder_basic() {
        let config = LlamaCppConfigBuilder::default()
            .model_path("model.gguf")
            .max_tokens(1024)
            .temperature(0.8)
            .build();

        assert_eq!(
            config.model_source,
            ModelSource::Gguf {
                model_path: "model.gguf".to_string(),
            }
        );
        assert_eq!(config.max_tokens, Some(1024));
        assert_eq!(config.temperature, Some(0.8));
    }

    // Flags and structured options: grammar, reasoning format, extra_body
    // JSON round-trip, GPU split settings.
    #[test]
    fn test_config_builder_optional_flags() {
        let config = LlamaCppConfigBuilder::default()
            .model_path("model.gguf")
            .force_json_grammar(true)
            .reasoning_format(LlamaCppReasoningFormat::Deepseek)
            .extra_body(serde_json::json!({
                "chat_template_kwargs": {
                    "enable_thinking": true
                }
            }))
            .mmproj_use_gpu(true)
            .split_mode(LlamaCppSplitMode::Layer)
            .use_mlock(true)
            .devices(vec![0, 1])
            .build();

        assert!(config.force_json_grammar);
        assert_eq!(
            config.reasoning_format,
            Some(LlamaCppReasoningFormat::Deepseek)
        );
        // Verify the nested JSON value survived serialization into extra_body.
        assert_eq!(
            config
                .extra_body
                .as_ref()
                .and_then(|v| v.get("chat_template_kwargs"))
                .and_then(|v| v.get("enable_thinking"))
                .and_then(|v| v.as_bool()),
            Some(true)
        );
        assert_eq!(config.mmproj_use_gpu, Some(true));
        assert_eq!(config.split_mode, Some(LlamaCppSplitMode::Layer));
        assert_eq!(config.use_mlock, Some(true));
        assert_eq!(config.devices, Some(vec![0, 1]));
    }

    // Reasoning format must stay opt-in: Default leaves it unset.
    #[test]
    fn test_config_default_reasoning_format_is_opt_in() {
        let config = LlamaCppConfig::default();
        assert_eq!(config.reasoning_format, None);
    }

    // Broad sweep over the remaining setters, including a Hugging Face
    // model source with per-field overrides.
    #[test]
    fn test_config_builder_selected_options() {
        let config = LlamaCppConfigBuilder::default()
            .model_source(ModelSource::huggingface_with_filename(
                "org/model",
                "model.gguf",
            ))
            .chat_template("chat-template")
            .system_prompt("system")
            .model_dir("cache")
            .hf_filename("override.gguf")
            .hf_revision("rev1")
            .mmproj_path("mmproj.gguf")
            .media_marker("[IMG]")
            .max_tokens(123)
            .temperature(0.5)
            .top_p(0.9)
            .top_k(42)
            .repeat_penalty(1.1)
            .frequency_penalty(0.2)
            .presence_penalty(0.3)
            .repeat_last_n(32)
            .seed(7)
            .n_ctx(2048)
            .n_batch(64)
            .n_ubatch(8)
            .n_threads(4)
            .n_threads_batch(2)
            .n_gpu_layers(3)
            .main_gpu(1)
            .build();

        assert!(matches!(
            config.model_source,
            ModelSource::HuggingFace { .. }
        ));
        assert_eq!(config.chat_template.as_deref(), Some("chat-template"));
        assert_eq!(config.system_prompt.as_deref(), Some("system"));
        assert_eq!(config.model_dir.as_deref(), Some("cache"));
        assert_eq!(config.hf_filename.as_deref(), Some("override.gguf"));
        assert_eq!(config.hf_revision.as_deref(), Some("rev1"));
        assert_eq!(config.mmproj_path.as_deref(), Some("mmproj.gguf"));
        assert_eq!(config.media_marker.as_deref(), Some("[IMG]"));
        assert_eq!(config.max_tokens, Some(123));
        assert_eq!(config.temperature, Some(0.5));
        assert_eq!(config.n_ctx, Some(2048));
        assert_eq!(config.n_threads, Some(4));
        assert_eq!(config.n_gpu_layers, Some(3));
        assert_eq!(config.main_gpu, Some(1));
    }
}