use std::collections::HashMap;
use crate::device_runtime::{
Access, BlockId, DeviceBlock, Generation, ResourceError, ResourceResult, StreamId,
XlogDeviceRuntime,
};
/// How a [`LaunchRecorder`] treats a buffer that has no runtime `DeviceBlock`
/// behind it (an "untracked" buffer).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RecorderMode {
/// Untracked buffers are skipped: nothing is recorded and no error is stored.
Permissive,
/// The first untracked buffer stores a `ResourceError::StreamMisuse`
/// rejection that `preflight` / `commit` later return.
Strict,
}
/// Accumulates the device blocks one launch reads and writes, then hands them
/// to an `XlogDeviceRuntime` in two phases: `preflight` (calls
/// `prepare_block_use` per block) and `commit` (calls `finish_block_use` per
/// block).
///
/// The recorder stores only `BlockId` snapshots, so it holds no borrow on the
/// buffers that were recorded.
pub struct LaunchRecorder {
// Stream passed to the runtime's prepare/finish calls for every recorded use.
launch_stream: StreamId,
// Decides whether untracked buffers are skipped or rejected.
mode: RecorderMode,
// Every tracked use recorded so far, in call order (not yet deduplicated).
uses: Vec<RecordedUse>,
// First rejection seen while recording: either a strict-mode untracked
// buffer or a record made after a successful preflight (any mode).
strict_reject: Option<ResourceError>,
// Set once `preflight` has run to completion without error.
preflighted: bool,
// Set at the end of a successful `commit`; read by `Drop`.
committed: bool,
}
/// One recorded use: which block was touched and how.
#[derive(Clone, Copy)]
struct RecordedUse {
// Identity snapshot taken with `BlockId::from_block` at record time.
block: BlockId,
// Read / Write / ReadWrite as declared by the recording method.
access: Access,
// Name of the recording method (e.g. "read", "write_column"). It is stored
// but not read anywhere in the visible code, hence the lint allowance.
#[allow(dead_code)]
label: &'static str,
}
impl LaunchRecorder {
    /// Creates a [`RecorderMode::Permissive`] recorder for work on `launch_stream`.
    pub fn new_permissive(launch_stream: StreamId) -> Self {
        Self::new(launch_stream, RecorderMode::Permissive)
    }

    /// Creates a [`RecorderMode::Strict`] recorder for work on `launch_stream`.
    pub fn new_strict(launch_stream: StreamId) -> Self {
        Self::new(launch_stream, RecorderMode::Strict)
    }

    /// Shared constructor: an empty recorder that is neither preflighted nor
    /// committed and carries no stored rejection.
    fn new(launch_stream: StreamId, mode: RecorderMode) -> Self {
        Self {
            launch_stream,
            mode,
            uses: Vec::new(),
            strict_reject: None,
            preflighted: false,
            committed: false,
        }
    }

    /// Returns the stream every recorded use is attributed to.
    pub fn launch_stream(&self) -> StreamId {
        self.launch_stream
    }

    /// Returns the mode this recorder was created with.
    pub fn mode(&self) -> RecorderMode {
        self.mode
    }

    /// Core recording routine behind every public `read*` / `write*` method.
    ///
    /// * `label` – name of the calling method; used in error messages and
    ///   stored on the recorded use.
    /// * `block` – the buffer's runtime block, or `None` for an untracked buffer.
    /// * `access` – how the launch touches the buffer.
    /// * `external` – selects which explanation a strict rejection carries
    ///   (external memory vs. legacy allocation).
    ///
    /// Never fails directly: problems are stored in `strict_reject` (first one
    /// wins) and surfaced by `preflight` / `commit`.
    fn note(
        &mut self,
        label: &'static str,
        block: Option<&DeviceBlock>,
        access: Access,
        external: bool,
    ) -> &mut Self {
        if self.preflighted {
            // The use set is frozen once preflight has succeeded. EVERY later
            // record is dropped, not just the first one: previously a second
            // post-preflight record slipped past this guard (because a
            // rejection was already stored) and was pushed onto `uses`.
            if self.strict_reject.is_none() {
                self.strict_reject = Some(ResourceError::StreamMisuse(format!(
                    "LaunchRecorder::{}: recorded after preflight — once preflight \
                     succeeds, the set of uses is frozen so commit-time discoveries \
                     cannot leave unprotected work in flight. Record this use BEFORE \
                     preflight (the recorder is lifetime-free; snapshots release the \
                     source borrow immediately, so kernel-param &mut borrows still \
                     work)",
                    label,
                )));
            }
            return self;
        }

        if let Some(b) = block {
            // Tracked buffer: snapshot its identity; no borrow is retained.
            self.uses.push(RecordedUse {
                block: BlockId::from_block(b),
                access,
                label,
            });
            return self;
        }

        // Untracked buffer: permissive mode skips it silently, strict mode
        // remembers the first such rejection.
        if self.mode == RecorderMode::Strict && self.strict_reject.is_none() {
            let why = if external {
                "external (DLPack / ArrowDevice) memory has no runtime identity; \
                 strict launch recorders cannot attach a cross-stream use to it. \
                 Use a permissive recorder OR coordinate the cross-stream \
                 synchronization explicitly outside xlog"
            } else {
                "buffer is legacy cudarc-backed (no runtime block); strict launch \
                 recorders require the allocation to be routed through \
                 GpuMemoryManager::with_runtime so a DeviceBlock is available"
            };
            self.strict_reject = Some(ResourceError::StreamMisuse(format!(
                "LaunchRecorder::{}: untracked buffer rejected — {}",
                label, why
            )));
        }
        self
    }

    /// Records a read of `slice`. Returns `self` for chaining.
    pub fn read<T: cudarc::driver::DeviceRepr>(
        &mut self,
        slice: &crate::memory::TrackedCudaSlice<T>,
    ) -> &mut Self {
        self.note("read", slice.runtime_block(), Access::Read, false)
    }

    /// Records a write to `slice`. Returns `self` for chaining.
    pub fn write<T: cudarc::driver::DeviceRepr>(
        &mut self,
        slice: &crate::memory::TrackedCudaSlice<T>,
    ) -> &mut Self {
        self.note("write", slice.runtime_block(), Access::Write, false)
    }

    /// Records a combined read and write of `slice`. Returns `self` for chaining.
    pub fn read_write<T: cudarc::driver::DeviceRepr>(
        &mut self,
        slice: &crate::memory::TrackedCudaSlice<T>,
    ) -> &mut Self {
        self.note(
            "read_write",
            slice.runtime_block(),
            Access::ReadWrite,
            false,
        )
    }

    /// Records a read of `col`; `col.is_external()` selects the strict-mode
    /// rejection text if the column has no runtime block.
    pub fn read_column(&mut self, col: &crate::memory::CudaColumn) -> &mut Self {
        self.note(
            "read_column",
            col.runtime_block(),
            Access::Read,
            col.is_external(),
        )
    }

    /// Records a write to `col`; `col.is_external()` selects the strict-mode
    /// rejection text if the column has no runtime block.
    pub fn write_column(&mut self, col: &crate::memory::CudaColumn) -> &mut Self {
        self.note(
            "write_column",
            col.runtime_block(),
            Access::Write,
            col.is_external(),
        )
    }

    /// Records a read given an already-extracted optional block.
    // No caller is visible in this file, hence the lint allowance.
    #[allow(dead_code)]
    pub(crate) fn read_view_runtime(&mut self, block: Option<&DeviceBlock>) -> &mut Self {
        self.note("read_view", block, Access::Read, false)
    }

    /// Number of tracked uses recorded so far (before deduplication).
    pub fn recorded_count(&self) -> usize {
        self.uses.len()
    }

    /// Validates the recorded uses against `runtime`; per the error text in
    /// `commit`, callers must do this before enqueueing CUDA work.
    ///
    /// # Errors
    ///
    /// * A stored rejection, re-wrapped as `StreamMisuse` carrying its display
    ///   text. The stored error is left in place, so `commit` reports it too.
    /// * `StreamMisuse` when uses were recorded but the runtime reports no
    ///   support for block-use tracking.
    /// * Any error from `prepare_block_use`. Blocks prepared earlier in the
    ///   loop are not undone here.
    ///
    /// `preflighted` is set only on success.
    pub fn preflight(&mut self, runtime: &XlogDeviceRuntime) -> ResourceResult<()> {
        if let Some(err) = &self.strict_reject {
            return Err(ResourceError::StreamMisuse(err.to_string()));
        }
        if !self.uses.is_empty() && !runtime.supports_block_use_tracking() {
            return Err(ResourceError::StreamMisuse(
                "LaunchRecorder::preflight: active resource does not support \
                 cross-stream use tracking. Build the runtime around \
                 AsyncCudaResource (or a decorator stack over it) for \
                 stream-lifetime-safe launches"
                    .to_string(),
            ));
        }
        // One prepare call per distinct block, with merged access.
        let deduped = dedup_uses(&self.uses);
        for use_ in &deduped {
            runtime.prepare_block_use(use_.block, self.launch_stream, use_.access)?;
        }
        self.preflighted = true;
        Ok(())
    }

    /// Consumes the recorder and reports every distinct recorded block to
    /// `runtime` via `finish_block_use`.
    ///
    /// # Errors
    ///
    /// * The stored rejection, if any (returned as-is).
    /// * `StreamMisuse` when uses were recorded but `preflight` never succeeded.
    /// * Any error from `finish_block_use`.
    ///
    /// An empty recorder commits successfully without a preflight. On any
    /// error the recorder is dropped with `committed` still false.
    pub fn commit(mut self, runtime: &XlogDeviceRuntime) -> ResourceResult<()> {
        if let Some(err) = self.strict_reject.take() {
            return Err(err);
        }
        if !self.uses.is_empty() && !self.preflighted {
            return Err(ResourceError::StreamMisuse(
                "LaunchRecorder::commit: non-empty recorder reached commit without \
                 a successful preflight. The caller MUST call preflight(&runtime) \
                 BEFORE enqueueing CUDA work; otherwise commit-time failures leave \
                 unprotected work in flight. See the preflight + commit contract \
                 in the LaunchRecorder doc"
                    .to_string(),
            ));
        }
        let deduped = dedup_uses(&self.uses);
        for use_ in &deduped {
            runtime.finish_block_use(use_.block, self.launch_stream, use_.access)?;
        }
        // Must be set before `self` drops at the end of this function so the
        // Drop diagnostic stays quiet.
        self.committed = true;
        Ok(())
    }
}
/// Collapses repeated uses of the same block into one entry each.
///
/// Two uses refer to the same block when their `ptr`, `generation` and
/// `device_ordinal` all match. The result keeps first-occurrence order; the
/// surviving entry keeps the first use's `block` and `label`, and its access is
/// the `combine_access` merge of every use of that block.
fn dedup_uses(uses: &[RecordedUse]) -> Vec<RecordedUse> {
    // Maps a block identity to its position in `merged`.
    let mut slot_of: HashMap<(u64, Generation, u32), usize> = HashMap::with_capacity(uses.len());
    let mut merged: Vec<RecordedUse> = Vec::with_capacity(uses.len());

    for recorded in uses {
        let id = &recorded.block;
        match slot_of.entry((id.ptr, id.generation, id.device_ordinal)) {
            std::collections::hash_map::Entry::Occupied(hit) => {
                // Block already seen: widen the access recorded for it.
                let existing = &mut merged[*hit.get()];
                existing.access = combine_access(existing.access, recorded.access);
            }
            std::collections::hash_map::Entry::Vacant(free) => {
                free.insert(merged.len());
                merged.push(*recorded);
            }
        }
    }

    merged
}
/// Merges two access kinds for the same block.
///
/// Identical kinds stay as they are; any mix of reading and writing (or either
/// side already being `ReadWrite`) yields `Access::ReadWrite`.
fn combine_access(a: Access, b: Access) -> Access {
    match a {
        Access::ReadWrite => Access::ReadWrite,
        Access::Read => match b {
            Access::Read => Access::Read,
            Access::Write | Access::ReadWrite => Access::ReadWrite,
        },
        Access::Write => match b {
            Access::Write => Access::Write,
            Access::Read | Access::ReadWrite => Access::ReadWrite,
        },
    }
}
impl Drop for LaunchRecorder {
    /// Emits a diagnostic on stderr (debug builds only) when a recorder that
    /// holds at least one use is dropped without `committed` having been set.
    /// Never panics and never touches the runtime.
    fn drop(&mut self) {
        let nothing_to_report = self.committed || self.uses.is_empty();
        if nothing_to_report {
            return;
        }
        #[cfg(debug_assertions)]
        eprintln!(
            "[xlog_cuda::launch] LaunchRecorder dropped without commit: \
             {} uses on launch_stream={} (mode={:?}) were NOT recorded; \
             cross-stream lifetime safety lost for this launch",
            self.uses.len(),
            self.launch_stream.0,
            self.mode,
        );
    }
}
// Unit tests for LaunchRecorder. Every test that needs a CUDA device returns
// early (and therefore passes vacuously) when device 0 cannot be opened.
#[cfg(test)]
mod tests {
use super::*;
use crate::device_runtime::{
AsyncCudaResource, DeviceMemoryResource, DirectCudaResource, StreamPool,
};
use crate::CudaDevice;
use std::sync::Arc;
use xlog_core::MemoryBudget;
// Builds a runtime backed by AsyncCudaResource on device 0 and acquires a
// launch stream from its pool. Returns None if the device cannot be created
// or no stream can be acquired.
fn try_async_runtime() -> Option<(Arc<CudaDevice>, Arc<XlogDeviceRuntime>, StreamId)> {
let device = Arc::new(CudaDevice::new(0).ok()?);
let pool = Arc::new(StreamPool::with_defaults(Arc::clone(&device)));
let async_resource: Box<dyn DeviceMemoryResource + Send + Sync> = Box::new(
AsyncCudaResource::new(Arc::clone(&device), 0, Arc::clone(&pool)),
);
let runtime = Arc::new(XlogDeviceRuntime::with_resource(
Arc::clone(&device),
0,
Arc::clone(&pool),
async_resource,
));
let launch_stream = pool.acquire().ok()?;
Some((device, runtime, launch_stream))
}
// Builds a runtime backed by DirectCudaResource on device 0 and pairs it with
// StreamId::DEFAULT as the launch stream. Returns None if the device cannot
// be created.
fn try_direct_runtime() -> Option<(Arc<CudaDevice>, Arc<XlogDeviceRuntime>, StreamId)> {
let device = Arc::new(CudaDevice::new(0).ok()?);
let pool = Arc::new(StreamPool::with_defaults(Arc::clone(&device)));
let direct: Box<dyn DeviceMemoryResource + Send + Sync> =
Box::new(DirectCudaResource::new(Arc::clone(&device), 0));
let runtime = Arc::new(XlogDeviceRuntime::with_resource(
Arc::clone(&device),
0,
Arc::clone(&pool),
direct,
));
Some((device, runtime, StreamId::DEFAULT))
}
// A recorder with no uses commits successfully in either mode, with no
// preflight.
#[test]
fn empty_commit_is_ok_in_both_modes() {
let Some((_d, rt, ls)) = try_async_runtime() else {
return;
};
LaunchRecorder::new_permissive(ls)
.commit(&rt)
.expect("permissive empty");
LaunchRecorder::new_strict(ls)
.commit(&rt)
.expect("strict empty");
}
// A buffer allocated through a manager built WITHOUT a runtime has no runtime
// block; a permissive recorder records nothing for it and still preflights
// and commits cleanly.
#[test]
fn permissive_skips_legacy_silently() {
let Some(device) = CudaDevice::new(0).ok().map(Arc::new) else {
return;
};
let pool = Arc::new(StreamPool::with_defaults(Arc::clone(&device)));
let async_resource: Box<dyn DeviceMemoryResource + Send + Sync> = Box::new(
AsyncCudaResource::new(Arc::clone(&device), 0, Arc::clone(&pool)),
);
let runtime = Arc::new(XlogDeviceRuntime::with_resource(
Arc::clone(&device),
0,
Arc::clone(&pool),
async_resource,
));
let launch_stream = pool.acquire().expect("acquire");
let manager = Arc::new(crate::GpuMemoryManager::new(
Arc::clone(&device),
MemoryBudget::with_limit(1024 * 1024),
));
let legacy = manager.alloc::<u8>(64).expect("legacy alloc");
assert!(legacy.runtime_block().is_none());
let mut rec = LaunchRecorder::new_permissive(launch_stream);
rec.read(&legacy);
assert_eq!(rec.recorded_count(), 0);
rec.preflight(&runtime).expect("permissive preflight");
rec.commit(&runtime).expect("permissive commit");
}
// The same untracked buffer makes a strict recorder fail at preflight with
// the "untracked buffer rejected" StreamMisuse message.
#[test]
fn strict_rejects_legacy_at_preflight() {
let Some((device, runtime, launch_stream)) = try_async_runtime() else {
return;
};
let manager = Arc::new(crate::GpuMemoryManager::new(
Arc::clone(&device),
MemoryBudget::with_limit(1024 * 1024),
));
let legacy = manager.alloc::<u8>(64).expect("legacy alloc");
let mut rec = LaunchRecorder::new_strict(launch_stream);
rec.read(&legacy);
let err = rec.preflight(&runtime);
match err {
Err(ResourceError::StreamMisuse(msg)) => {
assert!(msg.contains("untracked buffer rejected"), "msg: {}", msg);
}
other => panic!(
"strict mode must reject untracked buffer at preflight; got {:?}",
other
),
}
}
// With a tracked buffer recorded, preflight against the Direct-backed runtime
// must fail with the "does not support cross-stream use tracking" message.
#[test]
fn preflight_rejects_direct_runtime_before_enqueue() {
let Some((device, runtime, launch_stream)) = try_direct_runtime() else {
return;
};
let manager = Arc::new(crate::GpuMemoryManager::with_runtime(
Arc::clone(&device),
MemoryBudget::with_limit(1024 * 1024),
Arc::clone(&runtime),
));
let buf = manager.alloc::<u8>(64).expect("alloc");
assert!(buf.runtime_block().is_some());
let mut rec = LaunchRecorder::new_strict(launch_stream);
rec.read(&buf);
let err = rec.preflight(&runtime);
match err {
Err(ResourceError::StreamMisuse(msg)) => {
assert!(
msg.contains("does not support cross-stream use tracking"),
"msg: {}",
msg
);
}
other => panic!(
"preflight must reject Direct-backed runtime before enqueue; got {:?}",
other
),
}
}
// Happy path: record a tracked read, preflight, then commit on the
// Async-backed runtime.
#[test]
fn preflight_then_commit_async_runtime() {
let Some((device, runtime, launch_stream)) = try_async_runtime() else {
return;
};
let manager = Arc::new(crate::GpuMemoryManager::with_runtime(
Arc::clone(&device),
MemoryBudget::with_limit(1024 * 1024),
Arc::clone(&runtime),
));
let buf = manager.alloc::<u8>(64).expect("alloc");
let mut rec = LaunchRecorder::new_strict(launch_stream);
rec.read(&buf);
rec.preflight(&runtime).expect("preflight ok");
rec.commit(&runtime).expect("commit ok");
}
// Committing a non-empty recorder that was never preflighted must return
// StreamMisuse mentioning "without a successful preflight".
#[test]
fn commit_rejects_un_preflighted_strict_recorder() {
let Some((device, runtime, launch_stream)) = try_async_runtime() else {
return;
};
let manager = Arc::new(crate::GpuMemoryManager::with_runtime(
Arc::clone(&device),
MemoryBudget::with_limit(1024 * 1024),
Arc::clone(&runtime),
));
let buf = manager.alloc::<u8>(64).expect("alloc");
let mut rec = LaunchRecorder::new_strict(launch_stream);
rec.read(&buf);
let err = rec.commit(&runtime);
match err {
Err(ResourceError::StreamMisuse(msg)) => {
assert!(
msg.contains("without a successful preflight"),
"msg: {}",
msg
);
}
other => panic!(
"non-empty un-preflighted commit must return StreamMisuse, got {:?}",
other
),
}
}
// The preflight requirement applies only to non-empty recorders.
#[test]
fn empty_recorder_commit_without_preflight_is_ok() {
let Some((_d, rt, ls)) = try_async_runtime() else {
return;
};
LaunchRecorder::new_strict(ls)
.commit(&rt)
.expect("empty strict commit without preflight");
}
// Recording through a public method after a successful preflight must make
// commit fail with the "recorded after preflight" message.
#[test]
fn note_after_preflight_via_standard_method_is_rejected() {
let Some((device, runtime, launch_stream)) = try_async_runtime() else {
return;
};
let manager = Arc::new(crate::GpuMemoryManager::with_runtime(
Arc::clone(&device),
MemoryBudget::with_limit(1024 * 1024),
Arc::clone(&runtime),
));
let buf_a = manager.alloc::<u8>(64).expect("alloc a");
let buf_b = manager.alloc::<u8>(64).expect("alloc b");
let mut rec = LaunchRecorder::new_strict(launch_stream);
rec.read(&buf_a);
rec.preflight(&runtime).expect("preflight ok");
rec.read(&buf_b);
let err = rec.commit(&runtime);
match err {
Err(ResourceError::StreamMisuse(msg)) => {
assert!(msg.contains("recorded after preflight"), "msg: {}", msg);
}
other => panic!(
"post-preflight standard-method record must be rejected; got {:?}",
other
),
}
}
// Recording a write BEFORE preflight leaves the buffer free to be mutably
// borrowed afterwards (the `&mut buf_fresh` line compiles because the
// recorder holds no borrow), and commit still succeeds.
#[test]
fn pre_preflight_fresh_write_is_accepted() {
let Some((device, runtime, launch_stream)) = try_async_runtime() else {
return;
};
let manager = Arc::new(crate::GpuMemoryManager::with_runtime(
Arc::clone(&device),
MemoryBudget::with_limit(1024 * 1024),
Arc::clone(&runtime),
));
let buf_a = manager.alloc::<u8>(64).expect("alloc a");
let mut buf_fresh = manager.alloc::<u8>(64).expect("alloc fresh");
let mut rec = LaunchRecorder::new_strict(launch_stream);
rec.read(&buf_a);
rec.write(&buf_fresh);
rec.preflight(&runtime).expect("preflight ok");
let _kernel_param = &mut buf_fresh;
rec.commit(&runtime).expect("commit ok");
}
// Reading and then writing the same buffer must preflight and commit without
// error. The merged access itself is asserted in the pure dedup test below.
#[test]
fn read_then_write_same_block_dedupes_to_read_write() {
let Some((device, runtime, launch_stream)) = try_async_runtime() else {
return;
};
let manager = Arc::new(crate::GpuMemoryManager::with_runtime(
Arc::clone(&device),
MemoryBudget::with_limit(1024 * 1024),
Arc::clone(&runtime),
));
let buf = manager.alloc::<u8>(64).expect("alloc");
let mut rec = LaunchRecorder::new_strict(launch_stream);
rec.read(&buf);
rec.write(&buf);
rec.preflight(&runtime).expect("preflight");
rec.commit(&runtime).expect("commit");
}
// Pure (no GPU) check of dedup_uses: the same pointer with different
// generations stays as two entries, while two uses of the identical BlockId
// collapse into one ReadWrite entry.
#[test]
fn dedup_keys_on_full_block_id_not_ptr_alone() {
let block_a = BlockId {
ptr: 0xdead_beef,
generation: Generation(1),
alloc_stream: StreamId::DEFAULT,
device_ordinal: 0,
};
let block_b = BlockId {
ptr: 0xdead_beef,
generation: Generation(2),
alloc_stream: StreamId::DEFAULT,
device_ordinal: 0,
};
let uses = vec![
RecordedUse {
block: block_a,
access: Access::Read,
label: "read",
},
RecordedUse {
block: block_b,
access: Access::Write,
label: "write",
},
];
let deduped = dedup_uses(&uses);
assert_eq!(deduped.len(), 2, "ABA generations must NOT collapse");
assert_eq!(deduped[0].block.generation, Generation(1));
assert_eq!(deduped[0].access, Access::Read);
assert_eq!(deduped[1].block.generation, Generation(2));
assert_eq!(deduped[1].access, Access::Write);
let same_id = vec![
RecordedUse {
block: block_a,
access: Access::Read,
label: "read",
},
RecordedUse {
block: block_a,
access: Access::Write,
label: "write",
},
];
let collapsed = dedup_uses(&same_id);
assert_eq!(collapsed.len(), 1);
assert_eq!(collapsed[0].access, Access::ReadWrite);
}
// An owned column wrapping a runtime-backed slice exposes a runtime block, is
// recorded as exactly one use, and goes through preflight and commit.
#[test]
fn read_column_owned_runtime_backed() {
use crate::memory::CudaColumn;
let Some((device, runtime, launch_stream)) = try_async_runtime() else {
return;
};
let manager = Arc::new(crate::GpuMemoryManager::with_runtime(
Arc::clone(&device),
MemoryBudget::with_limit(1024 * 1024),
Arc::clone(&runtime),
));
let slice = manager.alloc::<u8>(64).expect("alloc");
let col = CudaColumn::owned(slice);
assert!(col.runtime_block().is_some());
let mut rec = LaunchRecorder::new_strict(launch_stream);
rec.read_column(&col);
assert_eq!(rec.recorded_count(), 1);
rec.preflight(&runtime).expect("preflight");
rec.commit(&runtime).expect("commit");
}
}