1use crate::log_options::LogOptions;
2use std::sync::OnceLock;
3use tracing_core::{Interest, Kind, Metadata, callsite, field, identify_callsite};
4
// Names of the fields every forwarded log event can carry. The field set of
// each callsite declared by `log_cs!` is built from this list, and the
// `OverridableFields` lookups use these exact names.
static FIELD_NAMES: &[&str] = &["message", "module"];
6
/// Field handles of one log callsite, resolved by name from its field set.
struct OverridableFields {
    /// Handle of the `"message"` field.
    message: tracing::field::Field,
    /// Handle of the `"module"` field (the Rust name differs from the field name).
    target: tracing::field::Field,
}
11
// Declares one static tracing callsite for a fixed level.
//
// Arguments, in order: the level expression, the name of the callsite static,
// the name of its `Metadata` static, the name of the lazily resolved
// `OverridableFields` static, and the name of the unit struct implementing
// `Callsite`.
macro_rules! log_cs {
    ($level:expr, $cs:ident, $meta:ident, $fields:ident, $ty:ident) => {
        struct $ty;
        static $cs: $ty = $ty;
        static $meta: Metadata<'static> = Metadata::new(
            // Event name and target shared by every forwarded message.
            "log event",
            "llama-cpp-bindings",
            $level,
            // File, line and module path are not recorded.
            ::core::option::Option::None,
            ::core::option::Option::None,
            ::core::option::Option::None,
            field::FieldSet::new(FIELD_NAMES, identify_callsite!(&$cs)),
            Kind::EVENT,
        );
        // Resolved on first access; both names are listed in FIELD_NAMES, so
        // the `expect`s only fire if that list and these lookups diverge.
        static $fields: std::sync::LazyLock<OverridableFields> = std::sync::LazyLock::new(|| {
            let fields = $meta.fields();
            OverridableFields {
                message: fields
                    .field("message")
                    .expect("message field defined in FIELD_NAMES"),
                target: fields
                    .field("module")
                    .expect("module field defined in FIELD_NAMES"),
            }
        });

        impl callsite::Callsite for $ty {
            // Interest updates are ignored.
            fn set_interest(&self, _: Interest) {}
            fn metadata(&self) -> &'static Metadata<'static> {
                &$meta
            }
        }
    };
}
// One callsite per tracing level that `meta_for_level` can hand out.
log_cs!(
    tracing_core::Level::DEBUG,
    DEBUG_CS,
    DEBUG_META,
    DEBUG_FIELDS,
    DebugCallsite
);
log_cs!(
    tracing_core::Level::INFO,
    INFO_CS,
    INFO_META,
    INFO_FIELDS,
    InfoCallsite
);
log_cs!(
    tracing_core::Level::WARN,
    WARN_CS,
    WARN_META,
    WARN_FIELDS,
    WarnCallsite
);
log_cs!(
    tracing_core::Level::ERROR,
    ERROR_CS,
    ERROR_META,
    ERROR_FIELDS,
    ErrorCallsite
);
74
/// Identifies which native library a log state belongs to.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Module {
    /// Messages produced by ggml.
    Ggml,
    /// Messages produced by llama.cpp.
    LlamaCpp,
}

impl Module {
    /// Returns the library name used as the prefix of the `module` event field.
    const fn name(self) -> &'static str {
        match self {
            Self::Ggml => "ggml",
            Self::LlamaCpp => "llama.cpp",
        }
    }
}
89
90fn meta_for_level(
91 level: llama_cpp_bindings_sys::ggml_log_level,
92) -> (&'static Metadata<'static>, &'static OverridableFields) {
93 match level {
94 llama_cpp_bindings_sys::GGML_LOG_LEVEL_DEBUG => (&DEBUG_META, &DEBUG_FIELDS),
95 llama_cpp_bindings_sys::GGML_LOG_LEVEL_INFO => (&INFO_META, &INFO_FIELDS),
96 llama_cpp_bindings_sys::GGML_LOG_LEVEL_WARN => (&WARN_META, &WARN_FIELDS),
97 llama_cpp_bindings_sys::GGML_LOG_LEVEL_ERROR => (&ERROR_META, &ERROR_FIELDS),
98 _ => {
99 unreachable!("Illegal log level to be called here")
100 }
101 }
102}
103
/// Per-library logging state handed to the native log callback as user data.
pub struct State {
    /// Logging options; the callback drops every message when `disabled` is set.
    pub options: LogOptions,
    /// Library this state belongs to; its name prefixes the `module` field.
    module: Module,
    /// Level and text of a message that is not newline-terminated yet.
    buffered: std::sync::Mutex<Option<(llama_cpp_bindings_sys::ggml_log_level, String)>>,
    /// Level of the most recent non-CONT message, stored as `i32`.
    previous_level: std::sync::atomic::AtomicI32,
    /// Set by `buffer_non_cont`; cleared when the buffered message is emitted.
    is_buffering: std::sync::atomic::AtomicBool,
}
111
112impl State {
113 pub fn new(module: Module, options: LogOptions) -> Self {
114 Self {
115 options,
116 module,
117 buffered: std::sync::Mutex::default(),
118 previous_level: std::sync::atomic::AtomicI32::default(),
119 is_buffering: std::sync::atomic::AtomicBool::default(),
120 }
121 }
122
123 fn generate_log(target: Module, level: llama_cpp_bindings_sys::ggml_log_level, text: &str) {
124 let (module, text) = text
130 .char_indices()
131 .take_while(|(_, c)| c.is_ascii_lowercase() || *c == '_')
132 .last()
133 .and_then(|(pos, _)| {
134 let next_two = text.get(pos + 1..pos + 3);
135 if next_two == Some(": ") {
136 let (sub_module, text) = text.split_at(pos + 1);
137 let text = text.split_at(2).1;
138 Some((Some(format!("{}::{sub_module}", target.name())), text))
139 } else {
140 None
141 }
142 })
143 .unwrap_or((None, text));
144
145 let (meta, fields) = meta_for_level(level);
146
147 tracing::dispatcher::get_default(|dispatcher| {
148 dispatcher.event(&tracing::Event::new(
149 meta,
150 &meta.fields().value_set(&[
151 (&fields.message, Some(&text as &dyn tracing::field::Value)),
152 (
153 &fields.target,
154 module.as_ref().map(|s| s as &dyn tracing::field::Value),
155 ),
156 ]),
157 ));
158 });
159 }
160
161 pub fn cont_buffered_log(&self, text: &str) {
163 let mut lock = self.buffered.lock().unwrap();
164
165 if let Some((previous_log_level, mut buffer)) = lock.take() {
166 buffer.push_str(text);
167 if buffer.ends_with('\n') {
168 self.is_buffering
169 .store(false, std::sync::atomic::Ordering::Release);
170 Self::generate_log(self.module, previous_log_level, buffer.as_str());
171 } else {
172 *lock = Some((previous_log_level, buffer));
173 }
174 } else {
175 let level = self
176 .previous_level
177 .load(std::sync::atomic::Ordering::Acquire)
178 as llama_cpp_bindings_sys::ggml_log_level;
179 tracing::warn!(
180 inferred_level = level,
181 text = text,
182 origin = "crate",
183 "llama.cpp sent out a CONT log without any previously buffered message"
184 );
185 *lock = Some((level, text.to_string()));
186 }
187 }
188
189 pub fn buffer_non_cont(&self, level: llama_cpp_bindings_sys::ggml_log_level, text: &str) {
191 debug_assert!(!text.ends_with('\n'));
192 debug_assert_ne!(level, llama_cpp_bindings_sys::GGML_LOG_LEVEL_CONT);
193
194 if let Some((previous_log_level, buffer)) = self
195 .buffered
196 .lock()
197 .unwrap()
198 .replace((level, text.to_string()))
199 {
200 tracing::warn!(
201 level = previous_log_level,
202 text = &buffer,
203 origin = "crate",
204 "Message buffered unnecessarily due to missing newline and not followed by a CONT"
205 );
206 Self::generate_log(self.module, previous_log_level, buffer.as_str());
207 }
208
209 self.is_buffering
210 .store(true, std::sync::atomic::Ordering::Release);
211 self.previous_level
212 .store(level as i32, std::sync::atomic::Ordering::Release);
213 }
214
215 pub fn emit_non_cont_line(&self, level: llama_cpp_bindings_sys::ggml_log_level, text: &str) {
217 debug_assert!(text.ends_with('\n'));
218 debug_assert_ne!(level, llama_cpp_bindings_sys::GGML_LOG_LEVEL_CONT);
219
220 if self
221 .is_buffering
222 .swap(false, std::sync::atomic::Ordering::Acquire)
223 && let Some((buf_level, buf_text)) = self.buffered.lock().unwrap().take()
224 {
225 tracing::warn!(
227 level = buf_level,
228 text = buf_text,
229 origin = "crate",
230 "llama.cpp message buffered spuriously due to missing \\n and being followed by a non-CONT message!"
231 );
232 Self::generate_log(self.module, buf_level, buf_text.as_str());
233 }
234
235 self.previous_level
236 .store(level as i32, std::sync::atomic::Ordering::Release);
237
238 let (text, newline) = text.split_at(text.len() - 1);
239 debug_assert_eq!(newline, "\n");
240
241 match level {
242 llama_cpp_bindings_sys::GGML_LOG_LEVEL_NONE => {
243 tracing::info!(no_log_level = true, text);
244 }
245 llama_cpp_bindings_sys::GGML_LOG_LEVEL_DEBUG
246 | llama_cpp_bindings_sys::GGML_LOG_LEVEL_INFO
247 | llama_cpp_bindings_sys::GGML_LOG_LEVEL_WARN
248 | llama_cpp_bindings_sys::GGML_LOG_LEVEL_ERROR => {
249 Self::generate_log(self.module, level, text)
250 }
251 llama_cpp_bindings_sys::GGML_LOG_LEVEL_CONT => unreachable!(),
252 _ => {
253 tracing::warn!(
254 level = level,
255 text = text,
256 origin = "crate",
257 "Unknown llama.cpp log level"
258 );
259 }
260 }
261 }
262
263 pub fn update_previous_level_for_disabled_log(
264 &self,
265 level: llama_cpp_bindings_sys::ggml_log_level,
266 ) {
267 if level != llama_cpp_bindings_sys::GGML_LOG_LEVEL_CONT {
268 self.previous_level
269 .store(level as i32, std::sync::atomic::Ordering::Release);
270 }
271 }
272
273 pub fn is_enabled_for_level(&self, level: llama_cpp_bindings_sys::ggml_log_level) -> bool {
275 let level = if level == llama_cpp_bindings_sys::GGML_LOG_LEVEL_CONT {
277 self.previous_level
278 .load(std::sync::atomic::Ordering::Relaxed)
279 as llama_cpp_bindings_sys::ggml_log_level
280 } else {
281 level
282 };
283 let (meta, _) = meta_for_level(level);
284 tracing::dispatcher::get_default(|dispatcher| dispatcher.enabled(meta))
285 }
286}
287
/// Log state registered with llama.cpp; initialised by `send_logs_to_tracing`.
pub static LLAMA_STATE: OnceLock<Box<State>> = OnceLock::new();
/// Log state registered with ggml; initialised by `send_logs_to_tracing`.
pub static GGML_STATE: OnceLock<Box<State>> = OnceLock::new();
290
291extern "C" fn logs_to_trace(
292 level: llama_cpp_bindings_sys::ggml_log_level,
293 text: *const ::std::os::raw::c_char,
294 data: *mut ::std::os::raw::c_void,
295) {
296 use std::borrow::Borrow;
301
302 let log_state = unsafe { &*(data as *const State) };
303
304 if log_state.options.disabled {
305 return;
306 }
307
308 if !log_state.is_enabled_for_level(level) {
310 log_state.update_previous_level_for_disabled_log(level);
311
312 return;
313 }
314
315 let text = unsafe { std::ffi::CStr::from_ptr(text) };
316 let text = text.to_string_lossy();
317 let text: &str = text.borrow();
318
319 if level == llama_cpp_bindings_sys::GGML_LOG_LEVEL_CONT {
325 log_state.cont_buffered_log(text);
326 } else if text.ends_with('\n') {
327 log_state.emit_non_cont_line(level, text);
328 } else {
329 log_state.buffer_non_cont(level, text);
330 }
331}
332
333pub fn send_logs_to_tracing(options: LogOptions) {
335 let llama_heap_state = Box::as_ref(
340 LLAMA_STATE.get_or_init(|| Box::new(State::new(Module::LlamaCpp, options.clone()))),
341 ) as *const _;
342 let ggml_heap_state =
343 Box::as_ref(GGML_STATE.get_or_init(|| Box::new(State::new(Module::Ggml, options))))
344 as *const _;
345
346 unsafe {
347 llama_cpp_bindings_sys::llama_log_set(Some(logs_to_trace), llama_heap_state as *mut _);
349 llama_cpp_bindings_sys::ggml_log_set(Some(logs_to_trace), ggml_heap_state as *mut _);
350 }
351}
352
353#[cfg(test)]
354mod tests {
355 use std::sync::{Arc, Mutex};
356
357 use tracing_subscriber::util::SubscriberInitExt;
358
359 use super::{Module, State, logs_to_trace};
360 use crate::log_options::LogOptions;
361
/// `Module::Ggml` is reported as `ggml`.
#[test]
fn module_name_ggml() {
    let name = Module::Ggml.name();

    assert_eq!(name, "ggml");
}
366
/// `Module::LlamaCpp` is reported as `llama.cpp`.
#[test]
fn module_name_llama_cpp() {
    let name = Module::LlamaCpp.name();

    assert_eq!(name, "llama.cpp");
}
371
/// A fresh state has logging enabled and nothing buffered.
#[test]
fn state_new_creates_empty_buffer() {
    let fresh = State::new(Module::LlamaCpp, LogOptions::default());

    assert!(!fresh.options.disabled);

    let pending = fresh
        .buffered
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    assert!(pending.is_none());
}
380
/// A filtered non-CONT level is remembered as the previous level.
#[test]
fn update_previous_level_for_disabled_log_stores_level() {
    use std::sync::atomic::Ordering;

    let state = State::new(Module::LlamaCpp, LogOptions::default());
    let warn_level = llama_cpp_bindings_sys::GGML_LOG_LEVEL_WARN;

    state.update_previous_level_for_disabled_log(warn_level);

    assert_eq!(
        state.previous_level.load(Ordering::Relaxed),
        warn_level as i32
    );
}
393
/// A filtered CONT must not overwrite the remembered level.
#[test]
fn update_previous_level_ignores_cont() {
    use std::sync::atomic::Ordering;

    let state = State::new(Module::LlamaCpp, LogOptions::default());

    for level in [
        llama_cpp_bindings_sys::GGML_LOG_LEVEL_ERROR,
        llama_cpp_bindings_sys::GGML_LOG_LEVEL_CONT,
    ] {
        state.update_previous_level_for_disabled_log(level);
    }

    assert_eq!(
        state.previous_level.load(Ordering::Relaxed),
        llama_cpp_bindings_sys::GGML_LOG_LEVEL_ERROR as i32
    );
}
407
/// Buffering an unterminated message raises the flag and stores level + text.
#[test]
fn buffer_non_cont_sets_buffering_flag() {
    use std::sync::atomic::Ordering;

    let state = State::new(Module::LlamaCpp, LogOptions::default());

    state.buffer_non_cont(llama_cpp_bindings_sys::GGML_LOG_LEVEL_INFO, "partial");

    let flag = state.is_buffering.load(Ordering::Relaxed);
    assert!(flag);

    let pending = state
        .buffered
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);

    assert!(pending.is_some());
    let (stored_level, stored_text) = pending.as_ref().unwrap();

    assert_eq!(*stored_level, llama_cpp_bindings_sys::GGML_LOG_LEVEL_INFO);
    assert_eq!(stored_text, "partial");
}
428
/// An unterminated CONT fragment is appended to the buffered text.
#[test]
fn cont_buffered_log_appends_to_existing_buffer() {
    let state = State::new(Module::LlamaCpp, LogOptions::default());

    state.buffer_non_cont(llama_cpp_bindings_sys::GGML_LOG_LEVEL_INFO, "hello ");
    state.cont_buffered_log("world");

    let pending = state
        .buffered
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);

    assert!(pending.is_some());
    let (_level, joined) = pending.as_ref().unwrap();

    assert_eq!(joined, "hello world");
}
444
/// Test fixture pairing the captured log output with the subscriber guard.
struct Logger {
    // Never read: stored only so the guard is dropped together with the `Logger`.
    #[allow(unused)]
    guard: tracing::subscriber::DefaultGuard,
    /// Chunks written by the subscriber, in write order.
    logs: Arc<Mutex<Vec<String>>>,
}
450
/// Cloneable writer that records every chunk written to it as a `String`.
#[derive(Clone)]
struct VecWriter(Arc<Mutex<Vec<String>>>);

impl std::io::Write for VecWriter {
    /// Stores `buf` as one entry and reports it as fully written; invalid
    /// UTF-8 is rejected with an `InvalidData` error and nothing is stored.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        match String::from_utf8(buf.to_vec()) {
            Ok(chunk) => {
                self.0.lock().unwrap().push(chunk);

                Ok(buf.len())
            }
            Err(utf8_error) => Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                utf8_error,
            )),
        }
    }

    /// Nothing is buffered, so flushing always succeeds.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
468
469 fn create_logger(max_level: tracing::Level) -> Logger {
470 let logs = Arc::new(Mutex::new(vec![]));
471 let writer = VecWriter(logs.clone());
472
473 Logger {
474 guard: tracing_subscriber::fmt()
475 .with_max_level(max_level)
476 .with_ansi(false)
477 .without_time()
478 .with_file(false)
479 .with_line_number(false)
480 .with_level(false)
481 .with_target(false)
482 .with_writer(move || writer.clone())
483 .finish()
484 .set_default(),
485 logs,
486 }
487 }
488
/// CONT fragments that continue a filtered-out message are filtered as well,
/// whether the newline arrives with the text or in a fragment of its own.
#[test]
fn cont_disabled_log() {
    let logger = create_logger(tracing::Level::INFO);
    let mut log_state = Box::new(State::new(Module::LlamaCpp, LogOptions::default()));
    let log_ptr = log_state.as_mut() as *mut State as *mut std::os::raw::c_void;

    logs_to_trace(
        llama_cpp_bindings_sys::GGML_LOG_LEVEL_DEBUG,
        c"Hello ".as_ptr(),
        log_ptr,
    );
    logs_to_trace(
        llama_cpp_bindings_sys::GGML_LOG_LEVEL_CONT,
        c"world\n".as_ptr(),
        log_ptr,
    );

    assert!(logger.logs.lock().unwrap().is_empty());

    logs_to_trace(
        llama_cpp_bindings_sys::GGML_LOG_LEVEL_DEBUG,
        c"Hello ".as_ptr(),
        log_ptr,
    );
    logs_to_trace(
        llama_cpp_bindings_sys::GGML_LOG_LEVEL_CONT,
        c"world".as_ptr(),
        log_ptr,
    );
    logs_to_trace(
        llama_cpp_bindings_sys::GGML_LOG_LEVEL_CONT,
        c"\n".as_ptr(),
        log_ptr,
    );

    // The second scenario must stay silent too; previously it was not checked.
    assert!(logger.logs.lock().unwrap().is_empty());
}
524
/// An enabled message completed by a CONT fragment is written as one chunk.
/// The message keeps its own newline, and the formatter adds another.
#[test]
fn cont_enabled_log() {
    let capture = create_logger(tracing::Level::INFO);
    let mut state = Box::new(State::new(Module::LlamaCpp, LogOptions::default()));
    let user_data = (&mut *state as *mut State).cast::<std::os::raw::c_void>();

    for (level, text) in [
        (llama_cpp_bindings_sys::GGML_LOG_LEVEL_INFO, c"Hello "),
        (llama_cpp_bindings_sys::GGML_LOG_LEVEL_CONT, c"world\n"),
    ] {
        logs_to_trace(level, text.as_ptr(), user_data);
    }

    let captured = capture.logs.lock().unwrap();
    assert_eq!(*captured, vec!["Hello world\n\n"]);
}
545
/// With logging disabled in the options, nothing reaches the subscriber.
#[test]
fn disabled_logs_are_suppressed() {
    let capture = create_logger(tracing::Level::DEBUG);
    let options = LogOptions::default().with_logs_enabled(false);
    let mut state = Box::new(State::new(Module::LlamaCpp, options));
    let user_data = (&mut *state as *mut State).cast::<std::os::raw::c_void>();

    for (level, text) in [
        (
            llama_cpp_bindings_sys::GGML_LOG_LEVEL_INFO,
            c"Should not appear\n",
        ),
        (
            llama_cpp_bindings_sys::GGML_LOG_LEVEL_ERROR,
            c"Also suppressed\n",
        ),
    ] {
        logs_to_trace(level, text.as_ptr(), user_data);
    }

    let captured = capture.logs.lock().unwrap();
    assert!(captured.is_empty());
}
566
/// A complete INFO line produces exactly one captured entry.
#[test]
fn info_level_log_emitted() {
    let capture = create_logger(tracing::Level::INFO);
    let mut state = Box::new(State::new(Module::LlamaCpp, LogOptions::default()));
    let user_data = (&mut *state as *mut State).cast::<std::os::raw::c_void>();

    let message = c"info message\n";
    logs_to_trace(
        llama_cpp_bindings_sys::GGML_LOG_LEVEL_INFO,
        message.as_ptr(),
        user_data,
    );

    let captured = capture.logs.lock().unwrap();

    assert_eq!(captured.len(), 1);
    assert!(captured[0].contains("info message"));
}
584
/// A complete WARN line produces exactly one captured entry.
#[test]
fn warn_level_log_emitted() {
    let capture = create_logger(tracing::Level::WARN);
    let mut state = Box::new(State::new(Module::LlamaCpp, LogOptions::default()));
    let user_data = (&mut *state as *mut State).cast::<std::os::raw::c_void>();

    let message = c"warning message\n";
    logs_to_trace(
        llama_cpp_bindings_sys::GGML_LOG_LEVEL_WARN,
        message.as_ptr(),
        user_data,
    );

    let captured = capture.logs.lock().unwrap();

    assert_eq!(captured.len(), 1);
    assert!(captured[0].contains("warning message"));
}
602
/// A complete ERROR line produces exactly one captured entry.
#[test]
fn error_level_log_emitted() {
    let capture = create_logger(tracing::Level::ERROR);
    let mut state = Box::new(State::new(Module::LlamaCpp, LogOptions::default()));
    let user_data = (&mut *state as *mut State).cast::<std::os::raw::c_void>();

    let message = c"error message\n";
    logs_to_trace(
        llama_cpp_bindings_sys::GGML_LOG_LEVEL_ERROR,
        message.as_ptr(),
        user_data,
    );

    let captured = capture.logs.lock().unwrap();

    assert_eq!(captured.len(), 1);
    assert!(captured[0].contains("error message"));
}
620
/// A complete DEBUG line is captured when the subscriber allows DEBUG.
#[test]
fn debug_level_log_emitted_when_enabled() {
    let capture = create_logger(tracing::Level::DEBUG);
    let mut state = Box::new(State::new(Module::LlamaCpp, LogOptions::default()));
    let user_data = (&mut *state as *mut State).cast::<std::os::raw::c_void>();

    let message = c"debug message\n";
    logs_to_trace(
        llama_cpp_bindings_sys::GGML_LOG_LEVEL_DEBUG,
        message.as_ptr(),
        user_data,
    );

    let captured = capture.logs.lock().unwrap();

    assert_eq!(captured.len(), 1);
    assert!(captured[0].contains("debug message"));
}
638
/// A `name: ` prefix is moved out of the message into the `module` field,
/// qualified with the library name.
#[test]
fn submodule_extraction_from_log_text() {
    let logger = create_logger(tracing::Level::INFO);
    let mut log_state = Box::new(State::new(Module::LlamaCpp, LogOptions::default()));
    let log_ptr = log_state.as_mut() as *mut State as *mut std::os::raw::c_void;

    logs_to_trace(
        llama_cpp_bindings_sys::GGML_LOG_LEVEL_INFO,
        c"sampling: initialized\n".as_ptr(),
        log_ptr,
    );

    let logs = logger.logs.lock().unwrap();

    assert_eq!(logs.len(), 1);
    assert!(logs[0].contains("initialized"));
    // Without these checks the test would pass even if no extraction happened.
    assert!(logs[0].contains("llama.cpp::sampling"));
    assert!(!logs[0].contains("sampling: "));
}
656
/// Several CONT fragments are joined into a single emitted message.
#[test]
fn multi_part_cont_log() {
    let capture = create_logger(tracing::Level::INFO);
    let mut state = Box::new(State::new(Module::LlamaCpp, LogOptions::default()));
    let user_data = (&mut *state as *mut State).cast::<std::os::raw::c_void>();

    for (level, text) in [
        (llama_cpp_bindings_sys::GGML_LOG_LEVEL_INFO, c"part1 "),
        (llama_cpp_bindings_sys::GGML_LOG_LEVEL_CONT, c"part2 "),
        (llama_cpp_bindings_sys::GGML_LOG_LEVEL_CONT, c"part3\n"),
    ] {
        logs_to_trace(level, text.as_ptr(), user_data);
    }

    let captured = capture.logs.lock().unwrap();

    assert_eq!(captured.len(), 1);
    assert!(captured[0].contains("part1 part2 part3"));
}
684}