use crate::backend::{BackendError, DispatchConfig};
use vyre_foundation::ir::Program;
/// A buffer that belongs to exactly one backend.
///
/// Every implementor reports the id of its owning backend through
/// [`DeviceBuffer::backend_id`], which is what [`validate_buffer_ownership`]
/// compares against. Implementors also expose themselves as
/// [`std::any::Any`], so a backend can downcast a `dyn DeviceBuffer` back to
/// its own concrete buffer type.
pub trait DeviceBuffer: Send + Sync + std::fmt::Debug + std::any::Any {
    /// Id of the backend that owns this buffer.
    fn backend_id(&self) -> &'static str;

    /// Length of the buffer in bytes.
    fn byte_len(&self) -> usize;

    /// Optional human-readable label for diagnostics.
    ///
    /// The provided implementation reports no label; implementors that track
    /// one override it.
    fn debug_label(&self) -> Option<&str> {
        Option::None
    }

    /// Shared view of `self` as `Any`, for `downcast_ref`.
    fn as_any(&self) -> &dyn std::any::Any;

    /// Exclusive view of `self` as `Any`, for `downcast_mut`.
    fn as_any_mut(&mut self) -> &mut dyn std::any::Any;
}
/// Feature name used in `UnsupportedFeature` errors about `DeviceBuffer` support.
pub const DEVICE_BUFFER_FEATURE: &str = "DeviceBuffer";

/// Builds the `UnsupportedFeature` error for a backend with no `DeviceBuffer` support.
///
/// The error's `name` is [`DEVICE_BUFFER_FEATURE`] and its `backend` is `backend_id`.
pub(crate) fn unsupported_device_buffer(backend_id: &'static str) -> BackendError {
    let name = String::from(DEVICE_BUFFER_FEATURE);
    let backend = backend_id.to_owned();
    BackendError::UnsupportedFeature { name, backend }
}
/// A [`DeviceBuffer`] whose bytes live in an ordinary host `Vec<u8>`.
///
/// The buffer only records which backend id it was created for. Both
/// constructors return a boxed trait object, so callers reach the inherent
/// accessors (`as_slice`, `as_mut_slice`, `set_label`) by downcasting through
/// [`DeviceBuffer::as_any`] / [`DeviceBuffer::as_any_mut`].
#[derive(Debug)]
pub struct HostShimBuffer {
    // Field order is significant: it fixes the derived `Debug` output.
    /// Backend id reported by `DeviceBuffer::backend_id`.
    backend_id: &'static str,
    /// Buffer contents; its length is the reported byte length.
    bytes: Vec<u8>,
    /// Label reported by `DeviceBuffer::debug_label`, if one was set.
    label: Option<String>,
}

impl HostShimBuffer {
    /// Creates an unlabeled, zero-filled buffer of `byte_len` bytes owned by `backend_id`.
    #[must_use]
    pub fn allocate(backend_id: &'static str, byte_len: usize) -> Box<dyn DeviceBuffer> {
        Self::from_bytes(backend_id, vec![0; byte_len])
    }

    /// Wraps `bytes` (moved in, not copied) as an unlabeled buffer owned by `backend_id`.
    #[must_use]
    pub fn from_bytes(backend_id: &'static str, bytes: Vec<u8>) -> Box<dyn DeviceBuffer> {
        let shim = Self {
            backend_id,
            bytes,
            label: None,
        };
        Box::new(shim)
    }

    /// Read-only view of the buffer contents.
    #[must_use]
    pub fn as_slice(&self) -> &[u8] {
        self.bytes.as_slice()
    }

    /// Mutable view of the buffer contents; the length cannot change through it.
    #[must_use]
    pub fn as_mut_slice(&mut self) -> &mut [u8] {
        self.bytes.as_mut_slice()
    }

    /// Sets the label reported by `debug_label`, replacing any previous one.
    pub fn set_label(&mut self, label: impl Into<String>) {
        let text: String = label.into();
        self.label = Some(text);
    }
}

impl DeviceBuffer for HostShimBuffer {
    fn backend_id(&self) -> &'static str {
        self.backend_id
    }

    fn byte_len(&self) -> usize {
        self.as_slice().len()
    }

    fn debug_label(&self) -> Option<&str> {
        self.label.as_deref()
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self as &dyn std::any::Any
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self as &mut dyn std::any::Any
    }
}
/// Checks that every buffer in `buffers` is owned by `self_backend_id`.
///
/// Buffers are examined in iteration order, and checking stops at the first
/// buffer that belongs to a different backend.
///
/// # Errors
///
/// Returns `BackendError::UnsupportedFeature` for the first foreign buffer.
/// The error's `name` gives that buffer's zero-based position and its owning
/// backend id, and its `backend` is `self_backend_id`.
pub fn validate_buffer_ownership<'a>(
    self_backend_id: &str,
    buffers: impl IntoIterator<Item = &'a dyn DeviceBuffer>,
) -> Result<(), BackendError> {
    let first_foreign = buffers
        .into_iter()
        .enumerate()
        .find(|(_, candidate)| candidate.backend_id() != self_backend_id);

    match first_foreign {
        None => Ok(()),
        Some((idx, buffer)) => Err(BackendError::UnsupportedFeature {
            name: format!(
                "DeviceBuffer cross-backend dispatch (buffer {idx} owned by `{}`)",
                buffer.backend_id()
            ),
            backend: self_backend_id.to_string(),
        }),
    }
}
/// Fallback used by backends that have no native resident-buffer dispatch.
///
/// This function first checks that every input and output buffer belongs to
/// `backend`. It then always fails, because dispatching through host-shim
/// copies is deliberately forbidden. `program` and `config` are not inspected.
///
/// # Errors
///
/// - Returns the error from [`validate_buffer_ownership`] if any buffer
///   belongs to another backend. Inputs are checked before outputs, and each
///   list is indexed from zero on its own.
/// - Otherwise returns `BackendError::UnsupportedFeature`, explaining that a
///   backend-native implementation is required.
pub fn default_dispatch_with_device_buffers(
    backend: &dyn crate::backend::VyreBackend,
    program: &Program,
    inputs: &[&dyn DeviceBuffer],
    outputs: &mut [&mut dyn DeviceBuffer],
    config: &DispatchConfig,
) -> Result<(), BackendError> {
    // Nothing on this path uses the program or the launch configuration.
    // Binding them explicitly keeps the unused-parameter lint quiet.
    let _ = program;
    let _ = config;

    let owner = backend.id();
    validate_buffer_ownership(owner, inputs.iter().copied())?;

    // Reborrow each `&mut dyn DeviceBuffer` as a shared trait object for the check.
    let output_views = outputs.iter().map(|slot| &**slot as &dyn DeviceBuffer);
    validate_buffer_ownership(owner, output_views)?;

    Err(BackendError::UnsupportedFeature {
        name: String::from(
            "DeviceBuffer dispatch requires a backend-native resident-buffer implementation; host-shim dispatch fallback is forbidden",
        ),
        backend: owner.to_string(),
    })
}
/// Compile-time guard that `DeviceBuffer` stays usable as a trait object.
///
/// This constant only type-checks while `dyn DeviceBuffer` is a valid type.
/// Adding a method that breaks object safety therefore fails the build here.
const _ASSERT_DYN_SAFE: Option<&dyn DeviceBuffer> = Option::None;
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn host_shim_buffer_reports_size_and_backend() {
        let buf = HostShimBuffer::allocate("test-backend", 64);
        assert_eq!(buf.backend_id(), "test-backend");
        assert_eq!(buf.byte_len(), 64);
        assert!(buf.debug_label().is_none());
    }

    #[test]
    fn host_shim_buffer_allocates_zeroed_storage() {
        let buf = HostShimBuffer::allocate("test-backend", 16);
        let shim = buf
            .as_any()
            .downcast_ref::<HostShimBuffer>()
            .expect("Fix: HostShimBuffer");
        assert!(shim.as_slice().iter().all(|&byte| byte == 0));
        // A zero-length allocation is legal and reports zero bytes.
        assert_eq!(HostShimBuffer::allocate("test-backend", 0).byte_len(), 0);
    }

    #[test]
    fn host_shim_buffer_round_trips_bytes() {
        let mut buf = HostShimBuffer::allocate("test-backend", 8);
        let shim = buf
            .as_any_mut()
            .downcast_mut::<HostShimBuffer>()
            .expect("Fix: HostShimBuffer");
        shim.as_mut_slice()
            .copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]);
        let shim_ref = buf
            .as_any()
            .downcast_ref::<HostShimBuffer>()
            .expect("Fix: HostShimBuffer");
        assert_eq!(shim_ref.as_slice(), &[1, 2, 3, 4, 5, 6, 7, 8]);
    }

    #[test]
    fn host_shim_buffer_from_bytes_keeps_contents() {
        let buf = HostShimBuffer::from_bytes("test-backend", vec![7, 8, 9]);
        assert_eq!(buf.backend_id(), "test-backend");
        assert_eq!(buf.byte_len(), 3);
        assert!(buf.debug_label().is_none());
        let shim = buf
            .as_any()
            .downcast_ref::<HostShimBuffer>()
            .expect("Fix: HostShimBuffer");
        assert_eq!(shim.as_slice(), &[7, 8, 9]);
    }

    #[test]
    fn host_shim_buffer_label_is_visible_through_trait() {
        let mut buf = HostShimBuffer::allocate("test-backend", 4);
        buf.as_any_mut()
            .downcast_mut::<HostShimBuffer>()
            .expect("Fix: HostShimBuffer")
            .set_label("scratch");
        assert_eq!(buf.debug_label(), Some("scratch"));

        // Setting the label again replaces the previous value.
        buf.as_any_mut()
            .downcast_mut::<HostShimBuffer>()
            .expect("Fix: HostShimBuffer")
            .set_label(String::from("staging"));
        assert_eq!(buf.debug_label(), Some("staging"));
    }

    #[test]
    fn validate_buffer_ownership_rejects_cross_backend() {
        let cuda_buf = HostShimBuffer::allocate("cuda", 4);
        let wgpu_buf = HostShimBuffer::allocate("wgpu", 4);
        let result = validate_buffer_ownership("cuda", [cuda_buf.as_ref(), wgpu_buf.as_ref()]);
        match result {
            Err(BackendError::UnsupportedFeature { name, backend }) => {
                // The diagnostic must identify the offending slot and its owner.
                assert!(
                    name.contains("buffer 1 owned by `wgpu`"),
                    "unexpected feature name: {name}"
                );
                assert_eq!(backend, "cuda");
            }
            other => panic!("unexpected result: {other:?}"),
        }
    }

    #[test]
    fn validate_buffer_ownership_reports_first_foreign_buffer() {
        let metal_buf = HostShimBuffer::allocate("metal", 4);
        let wgpu_buf = HostShimBuffer::allocate("wgpu", 4);
        let result = validate_buffer_ownership("cuda", [metal_buf.as_ref(), wgpu_buf.as_ref()]);
        match result {
            Err(BackendError::UnsupportedFeature { name, .. }) => {
                assert!(
                    name.contains("buffer 0 owned by `metal`"),
                    "unexpected feature name: {name}"
                );
            }
            other => panic!("unexpected result: {other:?}"),
        }
    }

    #[test]
    fn validate_buffer_ownership_accepts_same_backend() {
        let a = HostShimBuffer::allocate("cuda", 4);
        let b = HostShimBuffer::allocate("cuda", 8);
        validate_buffer_ownership("cuda", [a.as_ref(), b.as_ref()])
            .expect("Fix: same-backend buffers must validate");
    }

    #[test]
    fn validate_buffer_ownership_accepts_empty_set() {
        validate_buffer_ownership("cuda", std::iter::empty::<&dyn DeviceBuffer>())
            .expect("Fix: an empty buffer set has no foreign owners");
    }

    #[test]
    fn unsupported_device_buffer_marks_feature_correctly() {
        let err = unsupported_device_buffer("test-backend");
        match err {
            BackendError::UnsupportedFeature { name, backend } => {
                assert_eq!(name, DEVICE_BUFFER_FEATURE);
                assert_eq!(backend, "test-backend");
            }
            other => panic!("unexpected variant: {other:?}"),
        }
    }
}