use std::fmt;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Mutex;
use std::time::{SystemTime, UNIX_EPOCH};
use super::resource::{
Access, AllocTag, BlockId, DeviceBlock, DeviceMemoryResource, Generation, ResourceError,
ResourceResult, StreamId,
};
/// Which wrapped resource operation a [`LogRecord`] describes.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LogAction {
/// Emitted by `LoggingResource::allocate`; displayed as `allocate`.
Allocate,
/// Emitted by `LoggingResource::deallocate`; displayed as `deallocate`.
Deallocate,
/// Emitted by `LoggingResource::reap_pending`; displayed as `reap_pending`.
ReapPending,
}
impl fmt::Display for LogAction {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
LogAction::Allocate => f.write_str("allocate"),
LogAction::Deallocate => f.write_str("deallocate"),
LogAction::ReapPending => f.write_str("reap_pending"),
}
}
}
/// Outcome of the wrapped operation as stored in a [`LogRecord`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum LogResult {
/// The wrapped call returned `Ok`.
Ok,
/// The wrapped call returned `Err`.
///
/// `kind` is the variant label produced by `classify_error` (for example
/// `"Driver"` or `"UseAfterFree"`); `message` is the error's `Display` text
/// after it has been passed through `truncate` with a cap of 256.
Err { kind: &'static str, message: String },
}
impl LogResult {
pub fn from_result<T>(r: &ResourceResult<T>) -> Self {
match r {
Ok(_) => LogResult::Ok,
Err(e) => LogResult::Err {
kind: classify_error(e),
message: truncate(format!("{}", e), 256),
},
}
}
}
/// Maps a [`ResourceError`] to the name of its variant.
///
/// The returned label is what ends up in `LogResult::Err::kind`. The match is
/// deliberately exhaustive (no `_` arm) so that adding a variant to
/// `ResourceError` forces this table to be updated.
fn classify_error(e: &ResourceError) -> &'static str {
    let label: &'static str = match *e {
        ResourceError::OutOfBudget { .. } => "OutOfBudget",
        ResourceError::Driver(_) => "Driver",
        ResourceError::StreamMisuse(_) => "StreamMisuse",
        ResourceError::UseAfterFree { .. } => "UseAfterFree",
        ResourceError::OutOfBounds { .. } => "OutOfBounds",
    };
    label
}
/// Shortens `s` to at most `cap` bytes of its original content and appends `…`.
///
/// Strings whose byte length is already `<= cap` are returned untouched. When
/// `cap` falls inside a multi-byte character the cut moves back to the
/// previous character boundary, so the result is always valid UTF-8. The
/// ellipsis is added *after* the cut, so a shortened string can be up to
/// three bytes longer than `cap`.
fn truncate(mut s: String, cap: usize) -> String {
    if s.len() <= cap {
        return s;
    }
    // Search downward from `cap` for a boundary; index 0 always is one, so the
    // fallback is never actually taken.
    let cut = (0..=cap)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);
    s.truncate(cut);
    s.push('…');
    s
}
/// One logged resource operation, handed to a [`LoggingSink`].
///
/// Fields that do not apply to an action are `None` (for example every
/// optional field of a `ReapPending` record).
#[derive(Clone, Debug)]
pub struct LogRecord {
/// Operation this record describes.
pub action: LogAction,
/// Device ordinal: taken from the wrapped resource for allocate/reap, and
/// from the block itself for deallocate.
pub device_ordinal: u32,
/// Stream passed to `allocate`, or the block's `alloc_stream` for `deallocate`.
pub stream_id: Option<StreamId>,
/// `ptr` field of the block involved; `None` when an allocation failed.
pub ptr: Option<u64>,
/// Block size in bytes; for a failed allocation, the size that was requested.
pub bytes: Option<usize>,
/// Allocation tag supplied by the caller or stored in the block.
pub tag: Option<AllocTag>,
/// Generation of the block involved; `None` when an allocation failed.
pub generation: Option<Generation>,
/// Hash of the calling thread's `ThreadId` (see `current_thread_id_u64`).
pub thread_id: u64,
/// Value drawn from the process-wide `ORDER_COUNTER` when the record was built.
pub order_counter: u64,
/// Nanoseconds since the UNIX epoch, or 0 if the system clock reads earlier
/// than the epoch (see `now_nanos`).
pub timestamp_nanos: u128,
/// Outcome of the wrapped call.
pub result: LogResult,
}
/// Process-wide sequence shared by every `LoggingResource`; the first value
/// handed out is 1.
static ORDER_COUNTER: AtomicU64 = AtomicU64::new(1);

/// Returns the next sequence number (the counter's value before the increment).
fn next_order_counter() -> u64 {
    // `fetch_add` is an atomic read-modify-write, so each caller gets a
    // distinct value even with `Relaxed`; nothing else is published through
    // this counter.
    let step = 1;
    ORDER_COUNTER.fetch_add(step, Ordering::Relaxed)
}
/// Wall-clock time as nanoseconds since the UNIX epoch.
///
/// Returns 0 when the system clock reports a time before the epoch, so
/// building a log record never fails because of a misconfigured clock.
fn now_nanos() -> u128 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_nanos(),
        Err(_) => 0,
    }
}
/// Returns a `u64` identifying the calling thread.
///
/// The value is the 64-bit FNV-1a hash of the `Debug` rendering of
/// `std::thread::current().id()`; hashing the rendering is used because the
/// function needs a plain integer for `LogRecord::thread_id`.
///
/// A thread's id never changes, so the hash is computed once per thread and
/// cached in a thread-local. This keeps the `String` allocation and hashing
/// off the allocate/deallocate logging path after the first call.
fn current_thread_id_u64() -> u64 {
    /// Formats the current `ThreadId` and folds its bytes with FNV-1a (64-bit).
    fn hash_current_thread_id() -> u64 {
        const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
        const FNV_PRIME: u64 = 0x100_0000_01b3;
        format!("{:?}", std::thread::current().id())
            .bytes()
            .fold(FNV_OFFSET_BASIS, |h, b| {
                (h ^ u64::from(b)).wrapping_mul(FNV_PRIME)
            })
    }

    thread_local! {
        static CACHED_THREAD_ID: u64 = hash_current_thread_id();
    }

    // `try_with` fails only while this thread's locals are being torn down;
    // in that window fall back to computing the value directly.
    CACHED_THREAD_ID
        .try_with(|id| *id)
        .unwrap_or_else(|_| hash_current_thread_id())
}
/// Failure reported by a [`LoggingSink`] when it cannot accept a record.
#[derive(Debug)]
pub enum SinkError {
/// The sink declined the record; displayed as `log sink refused record: <msg>`.
Refused(String),
/// The sink hit an I/O problem; displayed as `log sink io error: <msg>`.
Io(String),
}
impl fmt::Display for SinkError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
SinkError::Refused(m) => write!(f, "log sink refused record: {}", m),
SinkError::Io(m) => write!(f, "log sink io error: {}", m),
}
}
}
impl std::error::Error for SinkError {}
/// Destination for [`LogRecord`]s produced by `LoggingResource`.
///
/// The `Send + Sync` supertraits allow a sink to be held as
/// `Arc<dyn LoggingSink>`, which is how `LoggingResource` stores it.
pub trait LoggingSink: Send + Sync {
/// Accepts one record, taking ownership of it.
///
/// Returning `Err` does not fail the resource operation being logged:
/// `LoggingResource` only increments its dropped-record counter.
fn emit(&self, record: LogRecord) -> Result<(), SinkError>;
}
/// Sink that keeps every emitted record in a mutex-protected `Vec`, in the
/// order `emit` was called.
pub struct InMemorySink {
// Append-only until `clear()`; locked for the duration of each method call.
records: Mutex<Vec<LogRecord>>,
}
impl InMemorySink {
    /// Creates a sink holding no records.
    pub fn new() -> Self {
        Self::default()
    }

    /// Locks the record list.
    ///
    /// # Panics
    /// Panics with `"InMemorySink poisoned"` if the mutex is poisoned.
    fn guard(&self) -> std::sync::MutexGuard<'_, Vec<LogRecord>> {
        self.records.lock().expect("InMemorySink poisoned")
    }

    /// Returns a copy of all records currently stored, oldest first.
    /// The sink itself is left unchanged.
    pub fn snapshot(&self) -> Vec<LogRecord> {
        self.guard().to_vec()
    }

    /// Discards all stored records.
    pub fn clear(&self) {
        self.guard().clear();
    }

    /// Number of records currently stored.
    pub fn len(&self) -> usize {
        self.guard().len()
    }

    /// `true` when no records are stored.
    pub fn is_empty(&self) -> bool {
        self.guard().is_empty()
    }
}

impl Default for InMemorySink {
    /// Same as [`InMemorySink::new`]: an empty sink.
    fn default() -> Self {
        Self {
            records: Mutex::new(Vec::new()),
        }
    }
}
impl LoggingSink for InMemorySink {
fn emit(&self, record: LogRecord) -> Result<(), SinkError> {
self.records
.lock()
.expect("InMemorySink poisoned")
.push(record);
Ok(())
}
}
/// Sink that accepts every record and keeps none of them.
///
/// It is a stateless unit type, so the standard traits are derived
/// (`Default` included) instead of being written by hand.
#[derive(Clone, Copy, Debug, Default)]
pub struct NullSink;

impl NullSink {
    /// Creates a `NullSink`; equivalent to `NullSink::default()`.
    pub fn new() -> Self {
        Self
    }
}
impl LoggingSink for NullSink {
    /// Discards `record` and always reports success.
    fn emit(&self, record: LogRecord) -> Result<(), SinkError> {
        drop(record);
        Ok(())
    }
}
/// Wrapper around another `DeviceMemoryResource` that reports allocate,
/// deallocate and reap_pending calls to a [`LoggingSink`].
pub struct LoggingResource {
// Resource that performs the real work; every trait method forwards to it.
inner: Box<dyn DeviceMemoryResource + Send + Sync>,
// Destination for the records; held in an `Arc` so the caller can keep its own handle.
sink: std::sync::Arc<dyn LoggingSink>,
// Number of records for which `sink.emit` returned `Err`.
dropped_records: AtomicU64,
}
impl LoggingResource {
    /// Wraps `inner` so that its allocate/deallocate/reap_pending calls are
    /// reported to `sink`. The dropped-record counter starts at zero.
    pub fn new(
        inner: Box<dyn DeviceMemoryResource + Send + Sync>,
        sink: std::sync::Arc<dyn LoggingSink>,
    ) -> Self {
        let dropped_records = AtomicU64::new(0);
        Self {
            inner,
            sink,
            dropped_records,
        }
    }

    /// Number of records the sink has rejected so far.
    pub fn dropped_records(&self) -> u64 {
        self.dropped_records.load(Ordering::Relaxed)
    }

    /// Hands `record` to the sink. A sink error is deliberately not
    /// propagated: logging is best-effort, so the failure is only counted.
    fn emit(&self, record: LogRecord) {
        match self.sink.emit(record) {
            Ok(()) => {}
            Err(_) => {
                self.dropped_records.fetch_add(1, Ordering::Relaxed);
            }
        }
    }
}
/// Forwards every operation to the wrapped resource. `allocate`, `deallocate`
/// and `reap_pending` additionally emit one [`LogRecord`] after the inner call
/// has returned; the caller always receives the inner result unchanged.
impl DeviceMemoryResource for LoggingResource {
    /// Allocates through the wrapped resource and logs the outcome.
    ///
    /// On success the record carries the block's pointer, size and generation;
    /// on failure it carries no pointer/generation and the *requested* size.
    fn allocate(
        &self,
        bytes: usize,
        stream: StreamId,
        tag: AllocTag,
    ) -> ResourceResult<DeviceBlock> {
        let outcome = self.inner.allocate(bytes, stream, tag);
        // Borrowed view of the block, present only when the allocation succeeded.
        let granted = outcome.as_ref().ok();
        let record = LogRecord {
            action: LogAction::Allocate,
            device_ordinal: self.inner.device_ordinal(),
            stream_id: Some(stream),
            ptr: granted.map(|b| b.ptr),
            bytes: Some(granted.map_or(bytes, |b| b.bytes)),
            tag: Some(tag),
            generation: granted.map(|b| b.generation),
            thread_id: current_thread_id_u64(),
            order_counter: next_order_counter(),
            timestamp_nanos: now_nanos(),
            result: LogResult::from_result(&outcome),
        };
        self.emit(record);
        outcome
    }

    /// Deallocates through the wrapped resource and logs the outcome, using
    /// the identifying fields copied out of `block` before it is consumed.
    fn deallocate(&self, block: DeviceBlock) -> ResourceResult<()> {
        // `block` is moved into the inner call, so capture what the record needs first.
        let (ptr, bytes, tag, generation, stream, device_ordinal) = (
            block.ptr,
            block.bytes,
            block.tag,
            block.generation,
            block.alloc_stream,
            block.device_ordinal,
        );
        let outcome = self.inner.deallocate(block);
        let record = LogRecord {
            action: LogAction::Deallocate,
            device_ordinal,
            stream_id: Some(stream),
            ptr: Some(ptr),
            bytes: Some(bytes),
            tag: Some(tag),
            generation: Some(generation),
            thread_id: current_thread_id_u64(),
            order_counter: next_order_counter(),
            timestamp_nanos: now_nanos(),
            result: LogResult::from_result(&outcome),
        };
        self.emit(record);
        outcome
    }

    /// Device ordinal reported by the wrapped resource (not logged).
    fn device_ordinal(&self) -> u32 {
        self.inner.device_ordinal()
    }

    /// Outstanding byte count reported by the wrapped resource (not logged).
    fn bytes_outstanding(&self) -> usize {
        self.inner.bytes_outstanding()
    }

    /// Calls `reap_pending` on the wrapped resource and logs the outcome.
    /// The record has no stream, pointer, size, tag or generation.
    fn reap_pending(&self) -> ResourceResult<()> {
        let outcome = self.inner.reap_pending();
        let record = LogRecord {
            action: LogAction::ReapPending,
            device_ordinal: self.inner.device_ordinal(),
            stream_id: None,
            ptr: None,
            bytes: None,
            tag: None,
            generation: None,
            thread_id: current_thread_id_u64(),
            order_counter: next_order_counter(),
            timestamp_nanos: now_nanos(),
            result: LogResult::from_result(&outcome),
        };
        self.emit(record);
        outcome
    }

    /// Pure pass-through to the wrapped resource; nothing is logged.
    fn record_block_use(&self, block: &DeviceBlock, use_stream: StreamId) -> ResourceResult<()> {
        self.inner.record_block_use(block, use_stream)
    }

    /// Pure pass-through to the wrapped resource; nothing is logged.
    fn supports_block_use_tracking(&self) -> bool {
        self.inner.supports_block_use_tracking()
    }

    /// Pure pass-through to the wrapped resource; nothing is logged.
    fn prepare_block_use(
        &self,
        block: BlockId,
        use_stream: StreamId,
        access: Access,
    ) -> ResourceResult<()> {
        self.inner.prepare_block_use(block, use_stream, access)
    }

    /// Pure pass-through to the wrapped resource; nothing is logged.
    fn finish_block_use(
        &self,
        block: BlockId,
        use_stream: StreamId,
        access: Access,
    ) -> ResourceResult<()> {
        self.inner.finish_block_use(block, use_stream, access)
    }
}
// Unit tests for the logging wrapper and its sinks.
//
// Most tests wrap a `DirectCudaResource` and therefore need CUDA device 0.
// When `CudaDevice::new(0)` fails they return early, i.e. they pass without
// checking anything on machines that have no usable device.
#[cfg(test)]
mod tests {
use super::super::direct::DirectCudaResource;
use super::super::resource::BlockState;
use super::*;
use std::sync::Arc;
use crate::CudaDevice;
// Returns device 0, or `None` when it cannot be created.
fn try_device() -> Option<Arc<CudaDevice>> {
CudaDevice::new(0).ok().map(Arc::new)
}
// A successful allocate followed by a successful deallocate yields exactly
// two `Ok` records that describe the same block.
#[test]
fn pass_through_alloc_dealloc_emits_two_ok_records() {
let Some(device) = try_device() else {
return;
};
let inner = Box::new(DirectCudaResource::new(Arc::clone(&device), 0));
let sink = Arc::new(InMemorySink::new());
let r = LoggingResource::new(inner, sink.clone());
let block = r
.allocate(1024, StreamId::DEFAULT, AllocTag("logging-test"))
.expect("alloc");
assert_eq!(block.bytes, 1024);
assert_eq!(block.state, BlockState::Live);
r.deallocate(block).expect("dealloc");
let recs = sink.snapshot();
assert_eq!(recs.len(), 2, "expected 2 records, got {:?}", recs);
assert_eq!(recs[0].action, LogAction::Allocate);
assert_eq!(recs[0].result, LogResult::Ok);
assert_eq!(recs[0].bytes, Some(1024));
assert_eq!(recs[0].stream_id, Some(StreamId::DEFAULT));
assert!(recs[0].ptr.is_some());
assert_eq!(recs[1].action, LogAction::Deallocate);
assert_eq!(recs[1].result, LogResult::Ok);
// The deallocate record must identify the block that was allocated.
assert_eq!(recs[1].ptr, recs[0].ptr);
assert_eq!(recs[1].generation, recs[0].generation);
}
// Records emitted sequentially from one thread carry strictly increasing
// `order_counter` values.
#[test]
fn order_counter_strictly_increases_across_records() {
let Some(device) = try_device() else {
return;
};
let inner = Box::new(DirectCudaResource::new(Arc::clone(&device), 0));
let sink = Arc::new(InMemorySink::new());
let r = LoggingResource::new(inner, sink.clone());
for _ in 0..4 {
let b = r
.allocate(64, StreamId::DEFAULT, AllocTag::UNTAGGED)
.expect("alloc");
r.deallocate(b).expect("dealloc");
}
r.reap_pending().expect("reap");
let recs = sink.snapshot();
// 4 allocate/deallocate pairs + 1 reap = 9 records.
assert_eq!(recs.len(), 9); let mut last = 0u64;
for rec in &recs {
assert!(
rec.order_counter > last,
"order_counter must strictly increase: prev={}, now={}",
last,
rec.order_counter
);
last = rec.order_counter;
}
}
// A failing allocation is still logged: error kind set, requested size kept,
// no pointer and no generation.
#[test]
fn failed_alloc_records_error_result() {
let Some(device) = try_device() else {
return;
};
let inner = Box::new(DirectCudaResource::new(Arc::clone(&device), 0));
let sink = Arc::new(InMemorySink::new());
let r = LoggingResource::new(inner, sink.clone());
// The assertions below rely on this zero-byte request being rejected by the
// inner resource with an error classified as "Driver".
let _ = r.allocate(0, StreamId::DEFAULT, AllocTag::UNTAGGED);
let recs = sink.snapshot();
assert_eq!(recs.len(), 1);
assert_eq!(recs[0].action, LogAction::Allocate);
assert!(matches!(recs[0].result, LogResult::Err { kind, .. } if kind == "Driver"));
assert_eq!(recs[0].bytes, Some(0));
assert!(recs[0].ptr.is_none());
assert!(recs[0].generation.is_none());
}
// A failing deallocation is logged with the block's identifying fields.
#[test]
fn failed_dealloc_records_error_result() {
let Some(device) = try_device() else {
return;
};
let inner = Box::new(DirectCudaResource::new(Arc::clone(&device), 0));
let sink = Arc::new(InMemorySink::new());
let r = LoggingResource::new(inner, sink.clone());
// Hand-built block that `r` never allocated; the test expects the inner
// resource to reject it with a `UseAfterFree` error.
let bogus = DeviceBlock {
ptr: 0xdead_beef,
device_ordinal: 0,
alloc_stream: StreamId::DEFAULT,
bytes: 16,
align: 1,
tag: AllocTag::UNTAGGED,
generation: Generation::next(),
state: BlockState::Live,
};
let res = r.deallocate(bogus);
assert!(res.is_err());
let recs = sink.snapshot();
assert_eq!(recs.len(), 1);
assert_eq!(recs[0].action, LogAction::Deallocate);
assert!(matches!(recs[0].result, LogResult::Err { kind, .. } if kind == "UseAfterFree"));
assert_eq!(recs[0].ptr, Some(0xdead_beef));
}
// `reap_pending` produces one record with no stream and no pointer.
#[test]
fn reap_pending_emits_record() {
let Some(device) = try_device() else {
return;
};
let inner = Box::new(DirectCudaResource::new(Arc::clone(&device), 0));
let sink = Arc::new(InMemorySink::new());
let r = LoggingResource::new(inner, sink.clone());
r.reap_pending().expect("reap");
let recs = sink.snapshot();
assert_eq!(recs.len(), 1);
assert_eq!(recs[0].action, LogAction::ReapPending);
assert_eq!(recs[0].result, LogResult::Ok);
assert!(recs[0].stream_id.is_none());
assert!(recs[0].ptr.is_none());
}
// Logging is best-effort: a sink that rejects everything must not make the
// wrapped operations fail; rejected records are only counted.
#[test]
fn sink_failure_increments_dropped_records_but_does_not_break_alloc() {
struct RefuseAllSink;
impl LoggingSink for RefuseAllSink {
fn emit(&self, _r: LogRecord) -> Result<(), SinkError> {
Err(SinkError::Refused("test sink refuses all".into()))
}
}
let Some(device) = try_device() else {
return;
};
let inner = Box::new(DirectCudaResource::new(Arc::clone(&device), 0));
let sink = Arc::new(RefuseAllSink);
let r = LoggingResource::new(inner, sink);
let block = r
.allocate(128, StreamId::DEFAULT, AllocTag("refuse-test"))
.expect("alloc must succeed even when sink refuses");
r.deallocate(block).expect("dealloc must succeed too");
// One rejected record for the allocate, one for the deallocate.
assert_eq!(r.dropped_records(), 2);
}
// The non-logging accessors forward to the wrapped resource.
#[test]
fn forwards_bytes_outstanding_and_device_ordinal() {
let Some(device) = try_device() else {
return;
};
// The second constructor argument (3) is expected to come back from
// `device_ordinal()`.
let inner = Box::new(DirectCudaResource::new(Arc::clone(&device), 3));
let sink = Arc::new(InMemorySink::new());
let r = LoggingResource::new(inner, sink);
assert_eq!(r.device_ordinal(), 3);
assert_eq!(r.bytes_outstanding(), 0);
}
// `truncate` must never cut inside a multi-byte character. Needs no device.
#[test]
fn truncate_handles_non_ascii_at_cap_without_panicking() {
// 'é' occupies bytes 1..3, so a cap of 2 lands inside it and the cut
// moves back to byte 1.
let s = String::from("héllo");
let out = truncate(s, 2);
assert!(out.starts_with('h'));
assert!(out.ends_with('…'));
// Shorter than the cap: returned unchanged, no ellipsis.
let out = truncate(String::from("ok"), 100);
assert_eq!(out, "ok");
// Cap of zero keeps nothing but the ellipsis.
let out = truncate(String::from("héllo"), 0);
assert_eq!(out, "…");
let out = truncate(String::from("abcdefgh"), 3);
assert_eq!(out, "abc…");
}
// `NullSink` accepts a hand-built record. Needs no device.
#[test]
fn null_sink_accepts_records_without_retention() {
let sink = NullSink::new();
let rec = LogRecord {
action: LogAction::Allocate,
device_ordinal: 0,
stream_id: Some(StreamId::DEFAULT),
ptr: Some(0xdead_beef),
bytes: Some(64),
tag: Some(AllocTag::UNTAGGED),
generation: Some(Generation::next()),
thread_id: 0,
order_counter: 1,
timestamp_nanos: 0,
result: LogResult::Ok,
};
sink.emit(rec).expect("NullSink never errors");
}
}