#![no_std]
pub use hyperlight_guest_tracing_macro::*;
#[cfg(feature = "trace")]
pub use trace::{create_trace_record, flush_trace_buffer};
/// Maximum number of message bytes stored in a single [`TraceRecord`].
///
/// With the `trace` feature, longer messages are truncated when a record is
/// created from a `&str`.
pub const MAX_TRACE_MSG_LEN: usize = 64;

/// A single timestamped trace message, buffered in the guest and handed to the
/// host in batches.
///
/// `#[repr(C)]` pins the field order and offsets. Batches are handed over as a
/// raw pointer plus a count rather than being serialized, so the in-memory
/// layout is part of the guest/host contract. It must not be left to the
/// unspecified default (`repr(Rust)`) layout.
///
/// NOTE(review): the host must interpret records with this same layout —
/// confirm its definition matches.
#[derive(Debug, Copy, Clone)]
#[repr(C)]
pub struct TraceRecord {
    /// Time-stamp counter reading (`invariant_tsc::read_tsc`) taken when the
    /// record is created from a message.
    pub cycles: u64,
    /// Number of valid bytes at the start of `msg`.
    pub msg_len: usize,
    /// Message bytes. Only `msg[..msg_len]` is meaningful; the rest is zero
    /// padding.
    pub msg: [u8; MAX_TRACE_MSG_LEN],
}
/// Access to the x86-64 time-stamp counter (TSC).
pub mod invariant_tsc {
    use core::arch::x86_64::{__cpuid, _rdtsc};

    /// CPUID leaf whose EAX reports the highest supported extended leaf.
    const MAX_EXTENDED_LEAF_QUERY: u32 = 0x8000_0000;
    /// CPUID extended leaf that carries the invariant-TSC flag.
    const INVARIANT_TSC_LEAF: u32 = 0x8000_0007;
    /// Bit of EDX (from [`INVARIANT_TSC_LEAF`]) that is set for an invariant TSC.
    const INVARIANT_TSC_EDX_BIT: u32 = 8;

    /// Returns `true` if the CPU reports an invariant TSC.
    ///
    /// The flag's leaf is queried only after the highest-extended-leaf query
    /// confirms the CPU supports it. Otherwise the answer is `false`.
    pub fn has_invariant_tsc() -> bool {
        // SAFETY: `cpuid` only reads processor identification state.
        let leaf_supported = unsafe { __cpuid(MAX_EXTENDED_LEAF_QUERY) }.eax >= INVARIANT_TSC_LEAF;
        // `&&` short-circuits, so the second `cpuid` runs only for a supported leaf.
        leaf_supported
            && ((unsafe { __cpuid(INVARIANT_TSC_LEAF) }.edx >> INVARIANT_TSC_EDX_BIT) & 1) == 1
    }

    /// Returns the current value of the time-stamp counter (via `rdtsc`).
    pub fn read_tsc() -> u64 {
        // SAFETY: `rdtsc` only reads the processor's time-stamp counter.
        unsafe { _rdtsc() }
    }
}
#[cfg(feature = "trace")]
mod trace {
extern crate alloc;
use core::mem::MaybeUninit;
use hyperlight_common::outb::OutBAction;
use spin::Mutex;
use super::{MAX_TRACE_MSG_LEN, TraceRecord, invariant_tsc};
/// Callback type used by the global buffer. It receives the number of valid
/// records and the buffer's full backing slice.
type SendToHostFn = fn(u64, &[TraceRecord]);
/// Global trace buffer shared by `create_trace_record` and
/// `flush_trace_buffer`, guarded by a `spin::Mutex`. Completed batches are
/// forwarded with `send_to_host`.
static TRACE_BUFFER: Mutex<TraceBuffer<SendToHostFn>> =
Mutex::new(TraceBuffer::new(send_to_host));
/// Capacity of a `TraceBuffer`. The push that fills the last slot sends the
/// whole batch and restarts from slot 0.
const MAX_NO_OF_ENTRIES: usize = 32;
impl From<&str> for TraceRecord {
    /// Builds a trace record for `msg`, stamped with the current TSC value.
    ///
    /// Messages longer than [`MAX_TRACE_MSG_LEN`] bytes are truncated. The cut
    /// is made at the longest prefix that fits and ends on a UTF-8 character
    /// boundary. Truncation therefore never panics, and the stored bytes are
    /// always valid UTF-8. Unused trailing bytes of `msg` are zero.
    fn from(msg: &str) -> Self {
        let mut len = msg.len().min(MAX_TRACE_MSG_LEN);
        // Slicing a `str` in the middle of a multi-byte character panics, so
        // back off to the previous character boundary. This takes at most 3
        // steps, and index 0 is always a boundary.
        while !msg.is_char_boundary(len) {
            len -= 1;
        }
        let cycles = invariant_tsc::read_tsc();
        let mut bytes = [0u8; MAX_TRACE_MSG_LEN];
        bytes[..len].copy_from_slice(&msg.as_bytes()[..len]);
        TraceRecord {
            cycles,
            msg: bytes,
            msg_len: len,
        }
    }
}
/// Fixed-capacity batch of trace records, handed to a callback whenever it
/// fills up or is explicitly flushed.
///
/// The callback receives `(count, entries)`. `entries` is always the full
/// backing array, but only its first `count` records were written since the
/// previous hand-off.
struct TraceBuffer<F: Fn(u64, &[TraceRecord])> {
    /// Backing storage. Slots at `write_index..` hold stale or zeroed records.
    entries: [TraceRecord; MAX_NO_OF_ENTRIES],
    /// Slot the next record goes into; equal to the number of pending records.
    write_index: usize,
    /// Callback invoked with each completed or flushed batch.
    send_to_host: F,
}

impl<F: Fn(u64, &[TraceRecord])> TraceBuffer<F> {
    /// Creates an empty buffer with every slot zeroed. It is `const` so it can
    /// initialize a `static`.
    const fn new(f: F) -> Self {
        // SAFETY: `TraceRecord` consists only of integers and a byte array, so
        // the all-zero bit pattern is a valid value of it.
        let blank: TraceRecord = unsafe { MaybeUninit::zeroed().assume_init() };
        Self {
            send_to_host: f,
            write_index: 0,
            entries: [blank; MAX_NO_OF_ENTRIES],
        }
    }

    /// Stores `entry` in the next slot.
    ///
    /// When this fills the last slot, the index wraps back to 0 and the whole
    /// array is passed to the callback as a full batch.
    fn push(&mut self, entry: TraceRecord) {
        self.entries[self.write_index] = entry;
        let next_slot = self.write_index + 1;
        if next_slot < MAX_NO_OF_ENTRIES {
            self.write_index = next_slot;
        } else {
            self.write_index = 0;
            (self.send_to_host)(MAX_NO_OF_ENTRIES as u64, &self.entries);
        }
    }

    /// Passes any pending records to the callback and resets the buffer.
    ///
    /// Does nothing when no records are pending.
    fn flush(&mut self) {
        let pending = self.write_index;
        if pending == 0 {
            return;
        }
        (self.send_to_host)(pending as u64, &self.entries);
        self.write_index = 0;
    }
}
/// Hands a batch of trace records to the host by executing `out dx, al`.
///
/// Register contents when the `out` instruction runs:
/// - `dx`: the I/O port, `OutBAction::TraceRecord as u16`.
/// - `rax`: `len`, the number of records to report. Callers pass
///   `MAX_NO_OF_ENTRIES` from `TraceBuffer::push` or the pending count from
///   `TraceBuffer::flush`. `al`, the byte `out` actually writes, is only the
///   low byte of this count.
/// - `rcx`: `records.as_ptr()`, the address of the first record.
///
/// `records` is the buffer's whole backing array; only its first `len`
/// elements are meaningful.
///
/// NOTE(review): presumably the host's handler for this port reads the count
/// from `rax` and the pointer from `rcx`, and copies the records out before
/// the guest resumes. Confirm against the host implementation.
fn send_to_host(len: u64, records: &[TraceRecord]) {
// SAFETY (review): no `options(...)` are passed, so the compiler must assume
// the asm may read and write arbitrary memory. This keeps earlier stores into
// `records` from being deferred past the port write. `records` is borrowed for
// the whole call, so the pointer in `rcx` stays valid while the asm runs.
// `in` operands must still hold their values when the asm finishes, so this
// relies on the host preserving `rax`, `rcx` and `dx` — TODO confirm.
unsafe {
core::arch::asm!("out dx, al",
in("dx") OutBAction::TraceRecord as u16,
in("rax") len,
in("rcx") records.as_ptr() as u64);
}
}
#[inline(always)]
pub fn create_trace_record(msg: &str) {
let entry = TraceRecord::from(msg);
let mut buffer = TRACE_BUFFER.lock();
buffer.push(entry);
}
/// Sends any records still pending in the global trace buffer to the host.
///
/// Does nothing if no records are pending. The lock is held only for the
/// duration of the flush.
#[inline(always)]
pub fn flush_trace_buffer() {
    TRACE_BUFFER.lock().flush();
}
#[cfg(test)]
mod tests {
    use alloc::format;
    use core::cell::Cell;

    use super::*;

    /// Callback that ignores every batch; for tests that only inspect buffer state.
    fn mock_send_to_host(_len: u64, _records: &[TraceRecord]) {}

    /// Builds a record for a message that fits, without going through the
    /// `From<&str>` conversion.
    fn create_test_entry(msg: &str) -> TraceRecord {
        let cycles = invariant_tsc::read_tsc();
        TraceRecord {
            cycles,
            msg: {
                let mut arr = [0u8; MAX_TRACE_MSG_LEN];
                arr[..msg.len()].copy_from_slice(msg.as_bytes());
                arr
            },
            msg_len: msg.len(),
        }
    }

    #[test]
    fn test_push_trace_record() {
        let mut buffer = TraceBuffer::new(mock_send_to_host);
        let msg = "Test message";
        let entry = create_test_entry(msg);
        buffer.push(entry);
        assert_eq!(buffer.write_index, 1);
        assert_eq!(buffer.entries[0].msg_len, msg.len());
        assert_eq!(&buffer.entries[0].msg[..msg.len()], msg.as_bytes());
        assert!(buffer.entries[0].cycles > 0);
    }

    #[test]
    fn test_flush_trace_buffer() {
        // Records what the callback was handed: (count, msg_len of first record).
        let sent: Cell<Option<(u64, usize)>> = Cell::new(None);
        let mut buffer = TraceBuffer::new(|len: u64, records: &[TraceRecord]| {
            sent.set(Some((len, records[0].msg_len)));
        });
        let msg = "Test message";
        let entry = create_test_entry(msg);
        buffer.push(entry);
        assert_eq!(buffer.write_index, 1);
        assert_eq!(sent.get(), None, "a partial batch must not be sent before flush");
        assert_eq!(buffer.entries[0].msg_len, msg.len());
        assert_eq!(&buffer.entries[0].msg[..msg.len()], msg.as_bytes());
        assert!(buffer.entries[0].cycles > 0);
        buffer.flush();
        assert_eq!(sent.get(), Some((1, msg.len())));
        assert_eq!(buffer.write_index, 0);
        assert_eq!(buffer.entries[0].msg_len, msg.len());
        assert_eq!(&buffer.entries[0].msg[..msg.len()], msg.as_bytes());
        assert!(buffer.entries[0].cycles > 0);
    }

    #[test]
    fn test_flush_empty_buffer_is_noop() {
        let calls = Cell::new(0u32);
        let mut buffer = TraceBuffer::new(|_len: u64, _records: &[TraceRecord]| {
            calls.set(calls.get() + 1);
        });
        buffer.flush();
        assert_eq!(calls.get(), 0);
        assert_eq!(buffer.write_index, 0);
    }

    #[test]
    fn test_auto_flush_on_full() {
        let calls = Cell::new(0u32);
        let last_len = Cell::new(0u64);
        let mut buffer = TraceBuffer::new(|len: u64, records: &[TraceRecord]| {
            calls.set(calls.get() + 1);
            last_len.set(len);
            assert_eq!(records.len(), MAX_NO_OF_ENTRIES);
        });
        for i in 0..MAX_NO_OF_ENTRIES {
            assert_eq!(calls.get(), 0, "buffer was sent before it was full");
            let msg = format!("Message {}", i);
            let entry = create_test_entry(&msg);
            buffer.push(entry);
        }
        assert_eq!(calls.get(), 1);
        assert_eq!(last_len.get(), MAX_NO_OF_ENTRIES as u64);
        assert_eq!(buffer.write_index, 0);
        assert_eq!(buffer.entries[0].msg_len, "Message 0".len());
    }

    #[test]
    fn test_trace_record_creation_valid() {
        let msg = "Valid message";
        let entry = TraceRecord::from(msg);
        assert_eq!(entry.msg_len, msg.len());
        assert_eq!(&entry.msg[..msg.len()], msg.as_bytes());
        assert!(entry.cycles > 0);
    }

    #[test]
    fn test_trace_record_creation_too_long() {
        let long_msg = "A".repeat(MAX_TRACE_MSG_LEN + 1);
        let result = TraceRecord::from(long_msg.as_str());
        assert_eq!(result.msg_len, MAX_TRACE_MSG_LEN);
        assert_eq!(
            &result.msg[..MAX_TRACE_MSG_LEN],
            &long_msg.as_bytes()[..MAX_TRACE_MSG_LEN],
        );
    }
}
}