1use crate::log_options::LogOptions;
2use std::sync::OnceLock;
3use tracing_core::{Interest, Kind, Metadata, callsite, field, identify_callsite};
4
/// Names of the fields every forwarded llama.cpp/ggml event carries: the log text
/// (`message`) and the optional originating submodule (`module`).
static FIELD_NAMES: &[&str] = &[
    "message", //
    "module",
];
6
7struct OverridableFields {
8 message: tracing::field::Field,
9 target: tracing::field::Field,
10}
11
/// Declares one static `tracing` event callsite for a fixed level.
///
/// Expands to:
/// - a unit struct `$callsite_ty` implementing `Callsite`,
/// - a static instance `$callsite` of it,
/// - the callsite's `Metadata` in `$metadata` (name `"log event"`, target
///   `"llama-cpp-bindings"`, fields taken from `FIELD_NAMES`),
/// - a `LazyLock` in `$resolved` holding the `message`/`module` field handles,
///   resolved on first access.
macro_rules! log_cs {
    ($lvl:expr, $callsite:ident, $metadata:ident, $resolved:ident, $callsite_ty:ident) => {
        struct $callsite_ty;

        impl callsite::Callsite for $callsite_ty {
            // Interest is deliberately not cached; enablement is asked of the
            // dispatcher on each call instead.
            fn set_interest(&self, _: Interest) {}

            fn metadata(&self) -> &'static Metadata<'static> {
                &$metadata
            }
        }

        static $callsite: $callsite_ty = $callsite_ty;

        static $metadata: Metadata<'static> = Metadata::new(
            "log event",
            "llama-cpp-bindings",
            $lvl,
            ::core::option::Option::None,
            ::core::option::Option::None,
            ::core::option::Option::None,
            field::FieldSet::new(FIELD_NAMES, identify_callsite!(&$callsite)),
            Kind::EVENT,
        );

        static $resolved: std::sync::LazyLock<OverridableFields> =
            std::sync::LazyLock::new(|| {
                let field_set = $metadata.fields();
                let message = field_set
                    .field("message")
                    .expect("message field defined in FIELD_NAMES");
                let target = field_set
                    .field("module")
                    .expect("module field defined in FIELD_NAMES");
                OverridableFields { message, target }
            });
    };
}
46log_cs!(
47 tracing_core::Level::DEBUG,
48 DEBUG_CS,
49 DEBUG_META,
50 DEBUG_FIELDS,
51 DebugCallsite
52);
53log_cs!(
54 tracing_core::Level::INFO,
55 INFO_CS,
56 INFO_META,
57 INFO_FIELDS,
58 InfoCallsite
59);
60log_cs!(
61 tracing_core::Level::WARN,
62 WARN_CS,
63 WARN_META,
64 WARN_FIELDS,
65 WarnCallsite
66);
67log_cs!(
68 tracing_core::Level::ERROR,
69 ERROR_CS,
70 ERROR_META,
71 ERROR_FIELDS,
72 ErrorCallsite
73);
74
/// Library a log message originates from.
///
/// Its name becomes the root of the `module` field attached to forwarded events
/// (e.g. `llama.cpp::sampling`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Module {
    /// Messages delivered through `ggml_log_set`.
    Ggml,
    /// Messages delivered through `llama_log_set`.
    LlamaCpp,
}

impl Module {
    /// Human-readable library name used as the prefix of the `module` field.
    const fn name(self) -> &'static str {
        match self {
            Self::Ggml => "ggml",
            Self::LlamaCpp => "llama.cpp",
        }
    }
}
89
90fn meta_for_level(
91 level: llama_cpp_bindings_sys::ggml_log_level,
92) -> (&'static Metadata<'static>, &'static OverridableFields) {
93 match level {
94 llama_cpp_bindings_sys::GGML_LOG_LEVEL_DEBUG => (&DEBUG_META, &DEBUG_FIELDS),
95 llama_cpp_bindings_sys::GGML_LOG_LEVEL_INFO => (&INFO_META, &INFO_FIELDS),
96 llama_cpp_bindings_sys::GGML_LOG_LEVEL_WARN => (&WARN_META, &WARN_FIELDS),
97 llama_cpp_bindings_sys::GGML_LOG_LEVEL_ERROR => (&ERROR_META, &ERROR_FIELDS),
98 _ => {
99 unreachable!("Illegal log level to be called here")
100 }
101 }
102}
103
104pub struct State {
105 pub options: LogOptions,
106 module: Module,
107 buffered: std::sync::Mutex<Option<(llama_cpp_bindings_sys::ggml_log_level, String)>>,
108 previous_level: std::sync::atomic::AtomicI32,
109 is_buffering: std::sync::atomic::AtomicBool,
110}
111
112impl State {
113 pub fn new(module: Module, options: LogOptions) -> Self {
114 Self {
115 options,
116 module,
117 buffered: std::sync::Mutex::default(),
118 previous_level: std::sync::atomic::AtomicI32::default(),
119 is_buffering: std::sync::atomic::AtomicBool::default(),
120 }
121 }
122
123 fn generate_log(target: Module, level: llama_cpp_bindings_sys::ggml_log_level, text: &str) {
124 let (module, text) = text
130 .char_indices()
131 .take_while(|(_, ch)| ch.is_ascii_lowercase() || *ch == '_')
132 .last()
133 .and_then(|(pos, _)| {
134 let next_two = text.get(pos + 1..pos + 3);
135 if next_two == Some(": ") {
136 let (sub_module, text) = text.split_at(pos + 1);
137 let text = text.split_at(2).1;
138 Some((Some(format!("{}::{sub_module}", target.name())), text))
139 } else {
140 None
141 }
142 })
143 .unwrap_or((None, text));
144
145 let (meta, fields) = meta_for_level(level);
146
147 tracing::dispatcher::get_default(|dispatcher| {
148 dispatcher.event(&tracing::Event::new(
149 meta,
150 &meta.fields().value_set(&[
151 (&fields.message, Some(&text as &dyn tracing::field::Value)),
152 (
153 &fields.target,
154 module
155 .as_ref()
156 .map(|module_name| module_name as &dyn tracing::field::Value),
157 ),
158 ]),
159 ));
160 });
161 }
162
163 pub fn cont_buffered_log(&self, text: &str) {
165 let mut lock = self.buffered.lock().unwrap();
166
167 if let Some((previous_log_level, mut buffer)) = lock.take() {
168 buffer.push_str(text);
169 if buffer.ends_with('\n') {
170 self.is_buffering
171 .store(false, std::sync::atomic::Ordering::Release);
172 Self::generate_log(self.module, previous_log_level, buffer.as_str());
173 } else {
174 *lock = Some((previous_log_level, buffer));
175 }
176 } else {
177 let level = self
178 .previous_level
179 .load(std::sync::atomic::Ordering::Acquire)
180 as llama_cpp_bindings_sys::ggml_log_level;
181 tracing::warn!(
182 inferred_level = level,
183 text = text,
184 origin = "crate",
185 "llama.cpp sent out a CONT log without any previously buffered message"
186 );
187 *lock = Some((level, text.to_string()));
188 }
189 }
190
191 pub fn buffer_non_cont(&self, level: llama_cpp_bindings_sys::ggml_log_level, text: &str) {
193 debug_assert!(!text.ends_with('\n'));
194 debug_assert_ne!(level, llama_cpp_bindings_sys::GGML_LOG_LEVEL_CONT);
195
196 if let Some((previous_log_level, buffer)) = self
197 .buffered
198 .lock()
199 .unwrap()
200 .replace((level, text.to_string()))
201 {
202 tracing::warn!(
203 level = previous_log_level,
204 text = &buffer,
205 origin = "crate",
206 "Message buffered unnecessarily due to missing newline and not followed by a CONT"
207 );
208 Self::generate_log(self.module, previous_log_level, buffer.as_str());
209 }
210
211 self.is_buffering
212 .store(true, std::sync::atomic::Ordering::Release);
213 self.previous_level
214 .store(level as i32, std::sync::atomic::Ordering::Release);
215 }
216
217 pub fn emit_non_cont_line(&self, level: llama_cpp_bindings_sys::ggml_log_level, text: &str) {
219 debug_assert!(text.ends_with('\n'));
220 debug_assert_ne!(level, llama_cpp_bindings_sys::GGML_LOG_LEVEL_CONT);
221
222 if self
223 .is_buffering
224 .swap(false, std::sync::atomic::Ordering::Acquire)
225 && let Some((buf_level, buf_text)) = self.buffered.lock().unwrap().take()
226 {
227 tracing::warn!(
229 level = buf_level,
230 text = buf_text,
231 origin = "crate",
232 "llama.cpp message buffered spuriously due to missing \\n and being followed by a non-CONT message!"
233 );
234 Self::generate_log(self.module, buf_level, buf_text.as_str());
235 }
236
237 self.previous_level
238 .store(level as i32, std::sync::atomic::Ordering::Release);
239
240 let (text, newline) = text.split_at(text.len() - 1);
241 debug_assert_eq!(newline, "\n");
242
243 match level {
244 llama_cpp_bindings_sys::GGML_LOG_LEVEL_NONE => {
245 tracing::info!(no_log_level = true, text);
246 }
247 llama_cpp_bindings_sys::GGML_LOG_LEVEL_DEBUG
248 | llama_cpp_bindings_sys::GGML_LOG_LEVEL_INFO
249 | llama_cpp_bindings_sys::GGML_LOG_LEVEL_WARN
250 | llama_cpp_bindings_sys::GGML_LOG_LEVEL_ERROR => {
251 Self::generate_log(self.module, level, text)
252 }
253 llama_cpp_bindings_sys::GGML_LOG_LEVEL_CONT => unreachable!(),
254 _ => {
255 tracing::warn!(
256 level = level,
257 text = text,
258 origin = "crate",
259 "Unknown llama.cpp log level"
260 );
261 }
262 }
263 }
264
265 pub fn update_previous_level_for_disabled_log(
266 &self,
267 level: llama_cpp_bindings_sys::ggml_log_level,
268 ) {
269 if level != llama_cpp_bindings_sys::GGML_LOG_LEVEL_CONT {
270 self.previous_level
271 .store(level as i32, std::sync::atomic::Ordering::Release);
272 }
273 }
274
275 pub fn is_enabled_for_level(&self, level: llama_cpp_bindings_sys::ggml_log_level) -> bool {
277 let level = if level == llama_cpp_bindings_sys::GGML_LOG_LEVEL_CONT {
279 self.previous_level
280 .load(std::sync::atomic::Ordering::Relaxed)
281 as llama_cpp_bindings_sys::ggml_log_level
282 } else {
283 level
284 };
285 let (meta, _) = meta_for_level(level);
286 tracing::dispatcher::get_default(|dispatcher| dispatcher.enabled(meta))
287 }
288}
289
290pub static LLAMA_STATE: OnceLock<Box<State>> = OnceLock::new();
291pub static GGML_STATE: OnceLock<Box<State>> = OnceLock::new();
292
293extern "C" fn logs_to_trace(
294 level: llama_cpp_bindings_sys::ggml_log_level,
295 text: *const ::std::os::raw::c_char,
296 data: *mut ::std::os::raw::c_void,
297) {
298 use std::borrow::Borrow;
303
304 let log_state = unsafe { &*(data as *const State) };
305
306 if log_state.options.disabled {
307 return;
308 }
309
310 if !log_state.is_enabled_for_level(level) {
312 log_state.update_previous_level_for_disabled_log(level);
313
314 return;
315 }
316
317 let text = unsafe { std::ffi::CStr::from_ptr(text) };
318 let text = text.to_string_lossy();
319 let text: &str = text.borrow();
320
321 if level == llama_cpp_bindings_sys::GGML_LOG_LEVEL_CONT {
327 log_state.cont_buffered_log(text);
328 } else if text.ends_with('\n') {
329 log_state.emit_non_cont_line(level, text);
330 } else {
331 log_state.buffer_non_cont(level, text);
332 }
333}
334
335pub fn send_logs_to_tracing(options: LogOptions) {
337 let llama_heap_state = Box::as_ref(
342 LLAMA_STATE.get_or_init(|| Box::new(State::new(Module::LlamaCpp, options.clone()))),
343 ) as *const _;
344 let ggml_heap_state =
345 Box::as_ref(GGML_STATE.get_or_init(|| Box::new(State::new(Module::Ggml, options))))
346 as *const _;
347
348 unsafe {
349 llama_cpp_bindings_sys::llama_log_set(Some(logs_to_trace), llama_heap_state as *mut _);
351 llama_cpp_bindings_sys::ggml_log_set(Some(logs_to_trace), ggml_heap_state as *mut _);
352 }
353}
354
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use tracing_subscriber::util::SubscriberInitExt;

    use super::{Module, State, logs_to_trace};
    use crate::log_options::LogOptions;

    /// Returns `state` as the opaque user-data pointer `logs_to_trace` expects.
    fn user_data(state: &mut State) -> *mut std::os::raw::c_void {
        state as *mut State as *mut std::os::raw::c_void
    }

    #[test]
    fn module_name_ggml() {
        assert_eq!(Module::Ggml.name(), "ggml");
    }

    #[test]
    fn module_name_llama_cpp() {
        assert_eq!(Module::LlamaCpp.name(), "llama.cpp");
    }

    #[test]
    fn state_new_creates_empty_buffer() {
        let state = State::new(Module::LlamaCpp, LogOptions::default());
        let buffer = state.buffered.lock().unwrap_or_else(|err| err.into_inner());

        assert!(buffer.is_none());
        assert!(!state.options.disabled);
    }

    #[test]
    fn update_previous_level_for_disabled_log_stores_level() {
        let state = State::new(Module::LlamaCpp, LogOptions::default());

        state.update_previous_level_for_disabled_log(llama_cpp_bindings_sys::GGML_LOG_LEVEL_WARN);

        let stored = state
            .previous_level
            .load(std::sync::atomic::Ordering::Relaxed);

        assert_eq!(stored, llama_cpp_bindings_sys::GGML_LOG_LEVEL_WARN as i32);
    }

    #[test]
    fn update_previous_level_ignores_cont() {
        let state = State::new(Module::LlamaCpp, LogOptions::default());

        state.update_previous_level_for_disabled_log(llama_cpp_bindings_sys::GGML_LOG_LEVEL_ERROR);
        state.update_previous_level_for_disabled_log(llama_cpp_bindings_sys::GGML_LOG_LEVEL_CONT);

        let stored = state
            .previous_level
            .load(std::sync::atomic::Ordering::Relaxed);

        assert_eq!(stored, llama_cpp_bindings_sys::GGML_LOG_LEVEL_ERROR as i32);
    }

    #[test]
    fn buffer_non_cont_sets_buffering_flag() {
        let state = State::new(Module::LlamaCpp, LogOptions::default());

        state.buffer_non_cont(llama_cpp_bindings_sys::GGML_LOG_LEVEL_INFO, "partial");

        assert!(
            state
                .is_buffering
                .load(std::sync::atomic::Ordering::Relaxed)
        );

        let buffer = state.buffered.lock().unwrap_or_else(|err| err.into_inner());

        assert!(buffer.is_some());
        let (level, text) = buffer.as_ref().unwrap();

        assert_eq!(*level, llama_cpp_bindings_sys::GGML_LOG_LEVEL_INFO);
        assert_eq!(text, "partial");
    }

    #[test]
    fn cont_buffered_log_appends_to_existing_buffer() {
        let state = State::new(Module::LlamaCpp, LogOptions::default());

        state.buffer_non_cont(llama_cpp_bindings_sys::GGML_LOG_LEVEL_INFO, "hello ");

        state.cont_buffered_log("world");

        let buffer = state.buffered.lock().unwrap_or_else(|err| err.into_inner());

        assert!(buffer.is_some());
        let (_, text) = buffer.as_ref().unwrap();

        assert_eq!(text, "hello world");
    }

    /// Test subscriber capturing every formatted line written by `tracing`.
    struct Logger {
        /// Held only so the thread-local default subscriber stays installed for the
        /// lifetime of the test; never read.
        _guard: tracing::subscriber::DefaultGuard,
        logs: Arc<Mutex<Vec<String>>>,
    }

    /// `io::Write` sink that appends each write to a shared vector as one string.
    #[derive(Clone)]
    struct VecWriter(Arc<Mutex<Vec<String>>>);

    impl std::io::Write for VecWriter {
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            let log_line = String::from_utf8(buf.to_vec()).map_err(|utf8_error| {
                std::io::Error::new(std::io::ErrorKind::InvalidData, utf8_error)
            })?;
            self.0.lock().unwrap().push(log_line);

            Ok(buf.len())
        }

        fn flush(&mut self) -> std::io::Result<()> {
            Ok(())
        }
    }

    /// Installs a plain-text `fmt` subscriber (no time, level, target or location)
    /// filtered at `max_level` as this thread's default.
    fn create_logger(max_level: tracing::Level) -> Logger {
        let logs = Arc::new(Mutex::new(vec![]));
        let writer = VecWriter(logs.clone());

        Logger {
            _guard: tracing_subscriber::fmt()
                .with_max_level(max_level)
                .with_ansi(false)
                .without_time()
                .with_file(false)
                .with_line_number(false)
                .with_level(false)
                .with_target(false)
                .with_writer(move || writer.clone())
                .finish()
                .set_default(),
            logs,
        }
    }

    #[test]
    fn cont_disabled_log() {
        let logger = create_logger(tracing::Level::INFO);
        let mut log_state = Box::new(State::new(Module::LlamaCpp, LogOptions::default()));
        let log_ptr = user_data(&mut log_state);

        logs_to_trace(
            llama_cpp_bindings_sys::GGML_LOG_LEVEL_DEBUG,
            c"Hello ".as_ptr(),
            log_ptr,
        );
        logs_to_trace(
            llama_cpp_bindings_sys::GGML_LOG_LEVEL_CONT,
            c"world\n".as_ptr(),
            log_ptr,
        );

        assert!(logger.logs.lock().unwrap().is_empty());

        logs_to_trace(
            llama_cpp_bindings_sys::GGML_LOG_LEVEL_DEBUG,
            c"Hello ".as_ptr(),
            log_ptr,
        );
        logs_to_trace(
            llama_cpp_bindings_sys::GGML_LOG_LEVEL_CONT,
            c"world".as_ptr(),
            log_ptr,
        );
        logs_to_trace(
            llama_cpp_bindings_sys::GGML_LOG_LEVEL_CONT,
            c"\n".as_ptr(),
            log_ptr,
        );

        // The multi-chunk DEBUG message must also be suppressed at INFO.
        assert!(logger.logs.lock().unwrap().is_empty());
    }

    #[test]
    fn cont_enabled_log() {
        let logger = create_logger(tracing::Level::INFO);
        let mut log_state = Box::new(State::new(Module::LlamaCpp, LogOptions::default()));
        let log_ptr = user_data(&mut log_state);

        logs_to_trace(
            llama_cpp_bindings_sys::GGML_LOG_LEVEL_INFO,
            c"Hello ".as_ptr(),
            log_ptr,
        );
        logs_to_trace(
            llama_cpp_bindings_sys::GGML_LOG_LEVEL_CONT,
            c"world\n".as_ptr(),
            log_ptr,
        );

        // The CONT-assembled message keeps its own newline, followed by fmt's.
        assert_eq!(*logger.logs.lock().unwrap(), vec!["Hello world\n\n"]);
    }

    #[test]
    fn disabled_logs_are_suppressed() {
        let logger = create_logger(tracing::Level::DEBUG);
        let disabled_options = LogOptions::default().with_logs_enabled(false);
        let mut log_state = Box::new(State::new(Module::LlamaCpp, disabled_options));
        let log_ptr = user_data(&mut log_state);

        logs_to_trace(
            llama_cpp_bindings_sys::GGML_LOG_LEVEL_INFO,
            c"Should not appear\n".as_ptr(),
            log_ptr,
        );
        logs_to_trace(
            llama_cpp_bindings_sys::GGML_LOG_LEVEL_ERROR,
            c"Also suppressed\n".as_ptr(),
            log_ptr,
        );

        assert!(logger.logs.lock().unwrap().is_empty());
    }

    #[test]
    fn info_level_log_emitted() {
        let logger = create_logger(tracing::Level::INFO);
        let mut log_state = Box::new(State::new(Module::LlamaCpp, LogOptions::default()));
        let log_ptr = user_data(&mut log_state);

        logs_to_trace(
            llama_cpp_bindings_sys::GGML_LOG_LEVEL_INFO,
            c"info message\n".as_ptr(),
            log_ptr,
        );

        let logs = logger.logs.lock().unwrap();

        assert_eq!(logs.len(), 1);
        assert!(logs[0].contains("info message"));
    }

    #[test]
    fn warn_level_log_emitted() {
        let logger = create_logger(tracing::Level::WARN);
        let mut log_state = Box::new(State::new(Module::LlamaCpp, LogOptions::default()));
        let log_ptr = user_data(&mut log_state);

        logs_to_trace(
            llama_cpp_bindings_sys::GGML_LOG_LEVEL_WARN,
            c"warning message\n".as_ptr(),
            log_ptr,
        );

        let logs = logger.logs.lock().unwrap();

        assert_eq!(logs.len(), 1);
        assert!(logs[0].contains("warning message"));
    }

    #[test]
    fn error_level_log_emitted() {
        let logger = create_logger(tracing::Level::ERROR);
        let mut log_state = Box::new(State::new(Module::LlamaCpp, LogOptions::default()));
        let log_ptr = user_data(&mut log_state);

        logs_to_trace(
            llama_cpp_bindings_sys::GGML_LOG_LEVEL_ERROR,
            c"error message\n".as_ptr(),
            log_ptr,
        );

        let logs = logger.logs.lock().unwrap();

        assert_eq!(logs.len(), 1);
        assert!(logs[0].contains("error message"));
    }

    #[test]
    fn debug_level_log_emitted_when_enabled() {
        let logger = create_logger(tracing::Level::DEBUG);
        let mut log_state = Box::new(State::new(Module::LlamaCpp, LogOptions::default()));
        let log_ptr = user_data(&mut log_state);

        logs_to_trace(
            llama_cpp_bindings_sys::GGML_LOG_LEVEL_DEBUG,
            c"debug message\n".as_ptr(),
            log_ptr,
        );

        let logs = logger.logs.lock().unwrap();

        assert_eq!(logs.len(), 1);
        assert!(logs[0].contains("debug message"));
    }

    #[test]
    fn submodule_extraction_from_log_text() {
        let logger = create_logger(tracing::Level::INFO);
        let mut log_state = Box::new(State::new(Module::LlamaCpp, LogOptions::default()));
        let log_ptr = user_data(&mut log_state);

        logs_to_trace(
            llama_cpp_bindings_sys::GGML_LOG_LEVEL_INFO,
            c"sampling: initialized\n".as_ptr(),
            log_ptr,
        );

        let logs = logger.logs.lock().unwrap();

        assert_eq!(logs.len(), 1);
        assert!(logs[0].contains("initialized"));
    }

    #[test]
    fn multi_part_cont_log() {
        let logger = create_logger(tracing::Level::INFO);
        let mut log_state = Box::new(State::new(Module::LlamaCpp, LogOptions::default()));
        let log_ptr = user_data(&mut log_state);

        logs_to_trace(
            llama_cpp_bindings_sys::GGML_LOG_LEVEL_INFO,
            c"part1 ".as_ptr(),
            log_ptr,
        );
        logs_to_trace(
            llama_cpp_bindings_sys::GGML_LOG_LEVEL_CONT,
            c"part2 ".as_ptr(),
            log_ptr,
        );
        logs_to_trace(
            llama_cpp_bindings_sys::GGML_LOG_LEVEL_CONT,
            c"part3\n".as_ptr(),
            log_ptr,
        );

        let logs = logger.logs.lock().unwrap();

        assert_eq!(logs.len(), 1);
        assert!(logs[0].contains("part1 part2 part3"));
    }
}