atlas_program_log/
logger.rs

1use core::{
2    cmp::min, mem::MaybeUninit, ops::Deref, ptr::copy_nonoverlapping, slice::from_raw_parts,
3};
4#[cfg(any(target_os = "atlas", target_arch = "bpf"))]
5use atlas_define_syscall::definitions::{
6    sol_log_, sol_memcpy_, sol_memset_, sol_remaining_compute_units,
7};
8
/// Ellipsis written in place of the elided part of a `str` truncated with
/// `Argument::TruncateEnd` / `Argument::TruncateStart`.
const TRUNCATED_SLICE: [u8; 3] = *b"...";

/// Marker byte that replaces the last byte of any value that did not fit.
const TRUNCATED: u8 = b'@';

/// Uninitialized byte, used as the repeat element of uninitialized byte arrays.
const UNINIT_BYTE: MaybeUninit<u8> = MaybeUninit::<u8>::uninit();
17
18/// Logger to efficiently format log messages.
19///
20/// The logger is a fixed size buffer that can be used to format log messages
21/// before sending them to the log output. Any type that implements the `Log`
22/// trait can be appended to the logger.
23pub struct Logger<const BUFFER: usize> {
24    // Byte buffer to store the log message.
25    buffer: [MaybeUninit<u8>; BUFFER],
26
27    // Length of the log message.
28    len: usize,
29}
30
31impl<const BUFFER: usize> Default for Logger<BUFFER> {
32    #[inline]
33    fn default() -> Self {
34        Self {
35            buffer: [UNINIT_BYTE; BUFFER],
36            len: 0,
37        }
38    }
39}
40
41impl<const BUFFER: usize> Deref for Logger<BUFFER> {
42    type Target = [u8];
43
44    fn deref(&self) -> &Self::Target {
45        // SAFETY: the slice is created from the buffer up to the length
46        // of the message.
47        unsafe { from_raw_parts(self.buffer.as_ptr() as *const _, self.len) }
48    }
49}
50
51impl<const BUFFER: usize> Logger<BUFFER> {
52    /// Append a value to the logger.
53    #[inline(always)]
54    pub fn append<T: Log>(&mut self, value: T) -> &mut Self {
55        self.append_with_args(value, &[]);
56        self
57    }
58
59    /// Append a value to the logger with formatting arguments.
60    #[inline]
61    pub fn append_with_args<T: Log>(&mut self, value: T, args: &[Argument]) -> &mut Self {
62        if self.is_full() {
63            if BUFFER > 0 {
64                // SAFETY: the buffer is checked to be full.
65                unsafe {
66                    let last = self.buffer.get_unchecked_mut(BUFFER - 1);
67                    last.write(TRUNCATED);
68                }
69            }
70        } else {
71            self.len += value.write_with_args(&mut self.buffer[self.len..], args);
72
73            if self.len > BUFFER {
74                // Indicates that the buffer is full.
75                self.len = BUFFER;
76                // SAFETY: the buffer length is checked to greater than `BUFFER`.
77                unsafe {
78                    let last = self.buffer.get_unchecked_mut(BUFFER - 1);
79                    last.write(TRUNCATED);
80                }
81            }
82        }
83
84        self
85    }
86
87    /// Log the message in the buffer.
88    #[inline(always)]
89    pub fn log(&self) {
90        log_message(self);
91    }
92
93    /// Clear the message buffer.
94    #[inline(always)]
95    pub fn clear(&mut self) {
96        self.len = 0;
97    }
98
99    /// Check whether the log buffer is at the maximum length or not.
100    #[inline(always)]
101    pub fn is_full(&self) -> bool {
102        self.len == BUFFER
103    }
104
105    /// Get the remaining space in the log buffer.
106    #[inline(always)]
107    pub fn remaining(&self) -> usize {
108        BUFFER - self.len
109    }
110}
111
/// Emit `message` as a single log entry.
///
/// - On `target_os = "atlas"` / `target_arch = "bpf"`, the bytes are handed to the
///   `sol_log_` syscall unchanged.
/// - On other targets with the `std` feature, the message is printed to standard
///   output. Any invalid UTF-8 is shown with replacement characters. Logger output
///   is truncated at byte granularity, so it can end in the middle of a multi-byte
///   character.
/// - On other targets without `std`, the message is only passed to
///   `core::hint::black_box`.
#[inline(always)]
pub fn log_message(message: &[u8]) {
    #[cfg(any(target_os = "atlas", target_arch = "bpf"))]
    // SAFETY: the message is always a valid pointer to a slice of bytes
    // and `sol_log_` is a syscall.
    unsafe {
        sol_log_(message.as_ptr(), message.len() as u64);
    }

    #[cfg(all(not(any(target_os = "atlas", target_arch = "bpf")), feature = "std"))]
    {
        // Decode lossily: strict decoding panicked on truncated multi-byte
        // characters and on `TruncateStart`/`TruncateEnd` cuts.
        let message = std::string::String::from_utf8_lossy(message);
        std::println!("{message}");
    }

    #[cfg(all(
        not(any(target_os = "atlas", target_arch = "bpf")),
        not(feature = "std")
    ))]
    core::hint::black_box(message);
}
133
/// Returns the number of compute units still available.
///
/// On `target_os = "atlas"` / `target_arch = "bpf"` this queries the
/// `sol_remaining_compute_units` syscall; on every other target it returns `0`.
#[inline(always)]
pub fn remaining_compute_units() -> u64 {
    #[cfg(not(any(target_os = "atlas", target_arch = "bpf")))]
    let units = core::hint::black_box(0u64);

    #[cfg(any(target_os = "atlas", target_arch = "bpf"))]
    // SAFETY: `sol_remaining_compute_units` is a syscall that takes no arguments
    // and returns the remaining compute units.
    let units = unsafe { sol_remaining_compute_units() };

    units
}
145
/// Formatting arguments.
///
/// Arguments specify additional formatting options for a logged value. The
/// implementations in this module ignore arguments that do not apply to the
/// value's type.
#[derive(Clone, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum Argument {
    /// Number of fractional digits to display for numbers.
    ///
    /// The integer is rendered as a fixed-point number with this many digits
    /// after the decimal point (e.g. `12345` with `Precision(2)` is written as
    /// `123.45`). A precision of `0` leaves the number unchanged.
    ///
    /// This is only applicable for numeric types.
    Precision(u8),

    /// Limit the output to at most this many bytes, keeping the start of the
    /// string. When the limit is exceeded, the removed tail is replaced by `...`;
    /// the ellipsis counts towards the limit.
    ///
    /// This is only applicable for `str` types.
    TruncateEnd(usize),

    /// Limit the output to at most this many bytes, keeping the end of the
    /// string. When the limit is exceeded, the removed head is replaced by `...`;
    /// the ellipsis counts towards the limit.
    ///
    /// This is only applicable for `str` types.
    TruncateStart(usize),
}
169
170/// Trait to specify the log behavior for a type.
171///
172/// # Safety
173///
174/// The implementation must ensure that the value returned by any of the methods correctly
175/// reflects the actual number of bytes written to the buffer. Returning a value greater
176/// than the number of bytes written to the buffer will result in undefined behavior, since
177/// it will lead to reading uninitialized memory from the buffer.
178pub unsafe trait Log {
179    #[inline(always)]
180    fn debug(&self, buffer: &mut [MaybeUninit<u8>]) -> usize {
181        self.debug_with_args(buffer, &[])
182    }
183
184    #[inline(always)]
185    fn debug_with_args(&self, buffer: &mut [MaybeUninit<u8>], args: &[Argument]) -> usize {
186        self.write_with_args(buffer, args)
187    }
188
189    #[inline(always)]
190    fn write(&self, buffer: &mut [MaybeUninit<u8>]) -> usize {
191        self.write_with_args(buffer, &[])
192    }
193
194    fn write_with_args(&self, buffer: &mut [MaybeUninit<u8>], parameters: &[Argument]) -> usize;
195}
196
197/// Implement the log trait for unsigned integer types.
198macro_rules! impl_log_for_unsigned_integer {
199    ( $type:tt ) => {
200        unsafe impl Log for $type {
201            #[inline]
202            fn write_with_args(&self, buffer: &mut [MaybeUninit<u8>], args: &[Argument]) -> usize {
203                // The maximum number of digits that the type can have.
204                const MAX_DIGITS: usize = $type::MAX.ilog10() as usize + 1;
205
206                if buffer.is_empty() {
207                    return 0;
208                }
209
210                match *self {
211                    // Handle zero as a special case.
212                    0 => {
213                        // SAFETY: the buffer is checked to be non-empty.
214                        unsafe {
215                            buffer.get_unchecked_mut(0).write(b'0');
216                        }
217                        1
218                    }
219                    mut value => {
220                        let mut digits = [UNINIT_BYTE; MAX_DIGITS];
221                        let mut offset = MAX_DIGITS;
222
223                        while value > 0 {
224                            let remainder = value % 10;
225                            value /= 10;
226                            offset -= 1;
227                            // SAFETY: the offset is always within the bounds of the array since
228                            // `offset` is initialized with the maximum number of digits that
229                            // the type can have and decremented on each iteration; `remainder`
230                            // is always less than 10.
231                            unsafe {
232                                digits
233                                    .get_unchecked_mut(offset)
234                                    .write(b'0' + remainder as u8);
235                            }
236                        }
237
238                        let precision = if let Some(Argument::Precision(p)) = args
239                            .iter()
240                            .find(|arg| matches!(arg, Argument::Precision(_)))
241                        {
242                            *p as usize
243                        } else {
244                            0
245                        };
246
247                        let written = MAX_DIGITS - offset;
248                        let length = buffer.len();
249
250                        // Space required with the specified precision. We might need
251                        // to add leading zeros and a decimal point, but this is only
252                        // if the precision is greater than zero.
253                        let required = match precision {
254                            0 => written,
255                            // decimal point
256                            _precision if precision < written => written + 1,
257                            // decimal point + one leading zero
258                            _ => precision + 2,
259                        };
260                        // Determines whether the value will be truncated or not.
261                        let is_truncated = required > length;
262                        // Cap the number of digits to write to the buffer length.
263                        let digits_to_write = min(MAX_DIGITS - offset, length);
264
265                        // SAFETY: the length of both `digits` and `buffer` arrays are guaranteed
266                        // to be within bounds and the `digits_to_write` value is capped to the
267                        // length of the `buffer`.
268                        unsafe {
269                            let source = digits.as_ptr().add(offset);
270                            let ptr = buffer.as_mut_ptr();
271
272                            // Copy the number to the buffer if no precision is specified.
273                            if precision == 0 {
274                                #[cfg(any(target_os = "atlas", target_arch = "bpf"))]
275                                sol_memcpy_(
276                                    ptr as *mut _,
277                                    source as *const _,
278                                    digits_to_write as u64,
279                                );
280                                #[cfg(not(any(target_os = "atlas", target_arch = "bpf")))]
281                                copy_nonoverlapping(source, ptr, digits_to_write);
282                            }
283                            // If padding is needed to satisfy the precision, add leading zeros
284                            // and a decimal point.
285                            else if precision >= digits_to_write {
286                                // Prefix.
287                                (ptr as *mut u8).write(b'0');
288
289                                if length > 2 {
290                                    (ptr.add(1) as *mut u8).write(b'.');
291                                    let padding = min(length - 2, precision - digits_to_write);
292
293                                    // Precision padding.
294                                    #[cfg(any(target_os = "atlas", target_arch = "bpf"))]
295                                    sol_memset_(ptr.add(2) as *mut _, b'0', padding as u64);
296                                    #[cfg(not(any(target_os = "atlas", target_arch = "bpf")))]
297                                    (ptr.add(2) as *mut u8).write_bytes(b'0', padding);
298
299                                    let current = 2 + padding;
300
301                                    // If there is still space, copy (part of) the number.
302                                    if current < length {
303                                        let remaining = min(digits_to_write, length - current);
304
305                                        // Number part.
306                                        #[cfg(any(target_os = "atlas", target_arch = "bpf"))]
307                                        sol_memcpy_(
308                                            ptr.add(current) as *mut _,
309                                            source as *const _,
310                                            remaining as u64,
311                                        );
312                                        #[cfg(not(any(
313                                            target_os = "atlas",
314                                            target_arch = "bpf"
315                                        )))]
316                                        copy_nonoverlapping(source, ptr.add(current), remaining);
317                                    }
318                                }
319                            }
320                            // No padding is needed, calculate the integer and fractional
321                            // parts and add a decimal point.
322                            else {
323                                let integer_part = digits_to_write - precision;
324
325                                // Integer part of the number.
326                                #[cfg(any(target_os = "atlas", target_arch = "bpf"))]
327                                sol_memcpy_(ptr as *mut _, source as *const _, integer_part as u64);
328                                #[cfg(not(any(target_os = "atlas", target_arch = "bpf")))]
329                                copy_nonoverlapping(source, ptr, integer_part);
330
331                                // Decimal point.
332                                (ptr.add(integer_part) as *mut u8).write(b'.');
333                                let current = integer_part + 1;
334
335                                // If there is still space, copy (part of) the remaining.
336                                if current < length {
337                                    let remaining = min(precision, length - current);
338
339                                    // Fractional part of the number.
340                                    #[cfg(any(target_os = "atlas", target_arch = "bpf"))]
341                                    sol_memcpy_(
342                                        ptr.add(current) as *mut _,
343                                        source.add(integer_part) as *const _,
344                                        remaining as u64,
345                                    );
346                                    #[cfg(not(any(target_os = "atlas", target_arch = "bpf")))]
347                                    copy_nonoverlapping(
348                                        source.add(integer_part),
349                                        ptr.add(current),
350                                        remaining,
351                                    );
352                                }
353                            }
354                        }
355
356                        let written = min(required, length);
357
358                        // There might not have been space.
359                        if is_truncated {
360                            // SAFETY: `written` is capped to the length of the buffer and
361                            // the required length (`required` is always greater than zero);
362                            // `buffer` is guaranteed  to have a length of at least 1.
363                            unsafe {
364                                buffer.get_unchecked_mut(written - 1).write(TRUNCATED);
365                            }
366                        }
367
368                        written
369                    }
370                }
371            }
372        }
373    };
374}
375
376// Supported unsigned integer types.
377impl_log_for_unsigned_integer!(u8);
378impl_log_for_unsigned_integer!(u16);
379impl_log_for_unsigned_integer!(u32);
380impl_log_for_unsigned_integer!(u64);
381impl_log_for_unsigned_integer!(usize);
382#[cfg(not(target_arch = "bpf"))]
383impl_log_for_unsigned_integer!(u128);
384
385/// Implement the log trait for the signed integer types.
386macro_rules! impl_log_for_signed {
387    ( $type:tt ) => {
388        unsafe impl Log for $type {
389            #[inline]
390            fn write_with_args(&self, buffer: &mut [MaybeUninit<u8>], args: &[Argument]) -> usize {
391                if buffer.is_empty() {
392                    return 0;
393                }
394
395                match *self {
396                    // Handle zero as a special case.
397                    0 => {
398                        // SAFETY: the buffer is checked to be non-empty.
399                        unsafe {
400                            buffer.get_unchecked_mut(0).write(b'0');
401                        }
402                        1
403                    }
404                    value => {
405                        let mut prefix = 0;
406
407                        if *self < 0 {
408                            if buffer.len() == 1 {
409                                // SAFETY: the buffer is checked to be non-empty.
410                                unsafe {
411                                    buffer.get_unchecked_mut(0).write(TRUNCATED);
412                                }
413                                // There is no space for the number, so just return.
414                                return 1;
415                            }
416
417                            // SAFETY: the buffer is checked to be non-empty.
418                            unsafe {
419                                buffer.get_unchecked_mut(0).write(b'-');
420                            }
421                            prefix += 1;
422                        };
423
424                        prefix
425                            + $type::unsigned_abs(value)
426                                .write_with_args(&mut buffer[prefix..], args)
427                    }
428                }
429            }
430        }
431    };
432}
433
434// Supported signed integer types.
435impl_log_for_signed!(i8);
436impl_log_for_signed!(i16);
437impl_log_for_signed!(i32);
438impl_log_for_signed!(i64);
439impl_log_for_signed!(isize);
440#[cfg(not(target_arch = "bpf"))]
441impl_log_for_signed!(i128);
442
443/// Implement the log trait for the `&str` type.
444unsafe impl Log for &str {
445    #[inline]
446    fn debug_with_args(&self, buffer: &mut [MaybeUninit<u8>], _args: &[Argument]) -> usize {
447        if buffer.is_empty() {
448            return 0;
449        }
450        // SAFETY: the buffer is checked to be non-empty.
451        unsafe {
452            buffer.get_unchecked_mut(0).write(b'"');
453        }
454
455        let mut offset = 1;
456        offset += self.write(&mut buffer[offset..]);
457
458        match buffer.len() - offset {
459            0 => {
460                // SAFETY: the buffer is guaranteed to be within `offset` bounds.
461                unsafe {
462                    buffer.get_unchecked_mut(offset - 1).write(TRUNCATED);
463                }
464            }
465            _ => {
466                // SAFETY: the buffer is guaranteed to be within `offset` bounds.
467                unsafe {
468                    buffer.get_unchecked_mut(offset).write(b'"');
469                }
470                offset += 1;
471            }
472        }
473
474        offset
475    }
476
477    #[inline]
478    fn write_with_args(&self, buffer: &mut [MaybeUninit<u8>], args: &[Argument]) -> usize {
479        // There are 4 different cases to consider:
480        //
481        // 1. No arguments were provided, so the entire string is copied to the buffer if it fits;
482        //    otherwise, the buffer is filled as many characters as possible and the last character
483        //    is set to `TRUNCATED`.
484        //
485        // Then cases only applicable when precision formatting is used:
486        //
487        // 2. The buffer is large enough to hold the entire string: the string is copied to the
488        //    buffer and the length of the string is returned.
489        //
490        // 3. The buffer is smaller than the string, but large enough to hold the prefix and part
491        //    of the string: the prefix and part of the string are copied to the buffer. The length
492        //    returned is `prefix` + number of characters copied.
493        //
494        // 4. The buffer is smaller than the string and the prefix: the buffer is filled with the
495        //    prefix and the last character is set to `TRUNCATED`. The length returned is the length
496        //    of the buffer.
497        //
498        // The length of the message is determined by whether a precision formatting was used or
499        // not, and the length of the buffer.
500
501        let (size, truncate_end) = match args
502            .iter()
503            .find(|arg| matches!(arg, Argument::TruncateEnd(_) | Argument::TruncateStart(_)))
504        {
505            Some(Argument::TruncateEnd(size)) => (*size, Some(true)),
506            Some(Argument::TruncateStart(size)) => (*size, Some(false)),
507            _ => (buffer.len(), None),
508        };
509
510        // Handles the write of the `str` to the buffer.
511        //
512        // - `destination`: pointer to the buffer where the string will be copied. This is always
513        //   the a pointer to the log buffer, but it could de in a different offset depending on
514        //   whether the truncated slice is copied or not.
515        //
516        // - `source`: pointer to the string that will be copied. This could either be a pointer
517        //   to the `str` itself or `TRUNCATE_SLICE`).
518        //
519        // - `length_to_write`: number of characters from `source` that will be copied.
520        //
521        // - `written_truncated_slice_length`: number of characters copied from `TRUNCATED_SLICE`.
522        //   This is used to determine the total number of characters copied to the buffer.
523        //
524        // - `truncated`: indicates whether the `str` was truncated or not. This is used to set
525        //   the last character of the buffer to `TRUNCATED`.
526        let (destination, source, length_to_write, written_truncated_slice_length, truncated) =
527            // No truncate arguments were provided, so the entire `str` is copied to the buffer
528            // if it fits; otherwise indicates that the `str` was truncated.
529            if truncate_end.is_none() {
530                let length = min(size, self.len());
531                (
532                    buffer.as_mut_ptr(),
533                    self.as_ptr(),
534                    length,
535                    0,
536                    length != self.len(),
537                )
538            } else {
539                let max_length = min(size, buffer.len());
540                let ptr = buffer.as_mut_ptr();
541
542                // The buffer is large enough to hold the entire `str`, so no need to use the
543                // truncate args.
544                if max_length >= self.len() {
545                    (ptr, self.as_ptr(), self.len(), 0, false)
546                }
547                // The buffer is large enough to hold the truncated slice and part of the string.
548                // In this case, the characters from the start or end of the string are copied to
549                // the buffer together with the `TRUNCATED_SLICE`.
550                else if max_length > TRUNCATED_SLICE.len() {
551                    // Number of characters that can be copied to the buffer.
552                    let length = max_length - TRUNCATED_SLICE.len();
553                    // SAFETY: the `ptr` is always within `length` bounds.
554                    unsafe {
555                        let (offset, source, destination) = if truncate_end == Some(true) {
556                            (length, self.as_ptr(), ptr)
557                        } else {
558                            (
559                                0,
560                                self.as_ptr().add(self.len() - length),
561                                ptr.add(TRUNCATED_SLICE.len()),
562                            )
563                        };
564                        // Copy the truncated slice to the buffer.
565                        copy_nonoverlapping(
566                            TRUNCATED_SLICE.as_ptr(),
567                            ptr.add(offset) as *mut _,
568                            TRUNCATED_SLICE.len(),
569                        );
570
571                        (destination, source, length, TRUNCATED_SLICE.len(), false)
572                    }
573                }
574                // The buffer is smaller than the `PREFIX`: the buffer is filled with the `PREFIX`
575                // and the last character is set to `TRUNCATED`.
576                else {
577                    (ptr, TRUNCATED_SLICE.as_ptr(), max_length, 0, true)
578                }
579            };
580
581        if length_to_write > 0 {
582            // SAFETY: the `destination` is always within `length_to_write` bounds.
583            unsafe {
584                #[cfg(any(target_os = "atlas", target_arch = "bpf"))]
585                sol_memcpy_(
586                    destination as *mut _,
587                    source as *const _,
588                    length_to_write as u64,
589                );
590                #[cfg(not(any(target_os = "atlas", target_arch = "bpf")))]
591                copy_nonoverlapping(source, destination as *mut _, length_to_write);
592            }
593
594            // There might not have been space for all the value.
595            if truncated {
596                // SAFETY: the `destination` is always within `length_to_write` bounds.
597                unsafe {
598                    let last = buffer.get_unchecked_mut(length_to_write - 1);
599                    last.write(TRUNCATED);
600                }
601            }
602        }
603
604        written_truncated_slice_length + length_to_write
605    }
606}
607
608/// Implement the log trait for the slice type.
609macro_rules! impl_log_for_slice {
610    ( [$type:ident] ) => {
611        unsafe impl<$type> Log for &[$type]
612        where
613            $type: Log
614        {
615            impl_log_for_slice!(@generate_write);
616        }
617    };
618    ( [$type:ident; $size:ident] ) => {
619        unsafe impl<$type, const $size: usize> Log for &[$type; $size]
620        where
621            $type: Log
622        {
623            impl_log_for_slice!(@generate_write);
624        }
625    };
626    ( @generate_write ) => {
627        #[inline]
628        fn write_with_args(&self, buffer: &mut [MaybeUninit<u8>], _args: &[Argument]) -> usize {
629            if buffer.is_empty() {
630                return 0;
631            }
632
633            // Size of the buffer.
634            let length = buffer.len();
635            // SAFETY: the buffer is checked to be non-empty.
636            unsafe {
637                buffer.get_unchecked_mut(0).write(b'[');
638            }
639
640            let mut offset = 1;
641
642            for value in self.iter() {
643                if offset >= length {
644                    // SAFETY: the buffer is checked to be non-empty and the `length`
645                    // represents the buffer length.
646                    unsafe {
647                        buffer.get_unchecked_mut(length - 1).write(TRUNCATED);
648                    }
649                    offset = length;
650                    break;
651                }
652
653                if offset > 1 {
654                    if offset + 2 >= length {
655                        // SAFETY: the buffer is checked to be non-empty and the `length`
656                        // represents the buffer length.
657                        unsafe {
658                            buffer.get_unchecked_mut(length - 1).write(TRUNCATED);
659                        }
660                        offset = length;
661                        break;
662                    } else {
663                        // SAFETY: the buffer is checked to be non-empty and the `offset`
664                        // is smaller than the buffer length.
665                        unsafe {
666                            buffer.get_unchecked_mut(offset).write(b',');
667                            buffer.get_unchecked_mut(offset + 1).write(b' ');
668                        }
669                        offset += 2;
670                    }
671                }
672
673                offset += value.debug(&mut buffer[offset..]);
674            }
675
676            if offset < length {
677                // SAFETY: the buffer is checked to be non-empty and the `offset`
678                // is smaller than the buffer length.
679                unsafe {
680                    buffer.get_unchecked_mut(offset).write(b']');
681                }
682                offset += 1;
683            }
684
685            offset
686        }
687    };
688}
689
690// Supported slice types.
691impl_log_for_slice!([T]);
692impl_log_for_slice!([T; N]);
693
694/// Implement the log trait for the `bool` type.
695unsafe impl Log for bool {
696    #[inline]
697    fn debug_with_args(&self, buffer: &mut [MaybeUninit<u8>], args: &[Argument]) -> usize {
698        let value = if *self { "true" } else { "false" };
699        value.debug_with_args(buffer, args)
700    }
701
702    #[inline]
703    fn write_with_args(&self, buffer: &mut [MaybeUninit<u8>], args: &[Argument]) -> usize {
704        let value = if *self { "true" } else { "false" };
705        value.write_with_args(buffer, args)
706    }
707}