//! Byte-range hazard tracking for Metal dispatches.
//!
//! [`MemRanges`] records the buffer ranges read and written by the dispatches
//! encoded since the last barrier, and reports whether a new dispatch would
//! introduce a RAW, WAR, or WAW hazard that requires one.

use crate::buffer::MlxBuffer;
use metal::foreign_types::ForeignType;
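/// How a dispatch uses a range: read from (`Src`) or written to (`Dst`).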
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MemRangeRole {
Src,
Dst,
}
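/// A half-open byte range `[p0, p1)` within a single Metal buffer, tagged
/// with how the dispatch uses it. `buf_id` is the raw `MTLBuffer` pointer,
/// which uniquely identifies the underlying allocation.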
#[derive(Clone, Copy, Debug)]
pub struct BufferRange {
pub buf_id: usize,
pub p0: u64,
pub p1: u64,
pub role: MemRangeRole,
}
impl BufferRange {
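    /// Builds the range covered by `buf`, honoring its view offset.
    ///
    /// The extent deliberately runs from the view's start to the end of the
    /// underlying allocation, so any two views of the same parent overlap.
    /// This is conservative: disjoint slices of one buffer still register as
    /// hazards (see the `slices_of_same_parent_conservative` test).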
#[inline]
    pub fn from_buffer(buf: &MlxBuffer, role: MemRangeRole) -> Self {
        // Identify the allocation by its raw MTLBuffer pointer.
        let buf_id = buf.metal_buffer().as_ptr() as usize;
        let base = buf.contents_ptr() as u64;
        let p0 = base + buf.byte_offset();
        // Extent from the view's start to the end of the allocation;
        // `saturating_sub` guards against an offset past the end.
        let extent = (buf.byte_len() as u64).saturating_sub(buf.byte_offset());
        let p1 = p0 + extent;
        Self { buf_id, p0, p1, role }
    }
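    /// Reports whether `self` and `other` require a barrier between their
    /// dispatches. The relation is symmetric: read-after-read never
    /// conflicts, while RAW, WAR, and WAW on overlapping ranges of the same
    /// buffer do.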
#[inline]
    pub fn conflicts_with(&self, other: &BufferRange) -> bool {
        // Distinct allocations never alias.
        if self.buf_id != other.buf_id {
            return false;
        }
        // Two reads never hazard (RAR), whatever their overlap.
        if self.role == MemRangeRole::Src && other.role == MemRangeRole::Src {
            return false;
        }
        // Overlap test; `>=` also flags abutting half-open ranges, erring
        // toward an extra barrier rather than a missed hazard.
        self.p0 < other.p1 && self.p1 >= other.p0
    }
}
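/// Tracks every buffer range touched by the dispatches encoded since the
/// last barrier, plus counters for how many hazard checks ran and how many
/// forced a barrier.
///
/// A sketch of the intended driver loop. `encode_barrier` and the shape of
/// `dispatches` are hypothetical stand-ins, not part of this module:
///
/// ```ignore
/// let mut mr = MemRanges::new();
/// for d in dispatches {
///     if !mr.check_and_record(&d.reads, &d.writes) {
///         encode_barrier();                     // hazard: fence prior work
///         mr.reset();                           // start a new epoch
///         mr.add_dispatch(&d.reads, &d.writes); // this dispatch opens it
///     }
/// }
/// ```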
pub struct MemRanges {
    /// Ranges recorded since the last `reset`.
    ranges: Vec<BufferRange>,
    /// Total number of `check_dispatch` calls.
    checks: u64,
    /// Number of checks that found a hazard and forced a barrier.
    barriers_forced: u64,
}
impl Default for MemRanges {
fn default() -> Self {
Self::new()
}
}
impl MemRanges {
    pub fn new() -> Self {
        Self {
            // Pre-sized to avoid reallocation at typical dispatch counts.
            ranges: Vec::with_capacity(256),
            checks: 0,
            barriers_forced: 0,
        }
    }
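    /// Clears the recorded ranges at a barrier boundary. The hazard counters
    /// intentionally persist across resets.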
#[inline]
pub fn reset(&mut self) {
self.ranges.clear();
}
#[inline]
pub fn len(&self) -> usize {
self.ranges.len()
}
#[inline]
pub fn is_empty(&self) -> bool {
self.ranges.is_empty()
}
#[inline]
pub fn checks(&self) -> u64 {
self.checks
}
#[inline]
pub fn barriers_forced(&self) -> u64 {
self.barriers_forced
}
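    /// Records a single pre-built range without any hazard check.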
#[inline]
pub fn push(&mut self, range: BufferRange) {
self.ranges.push(range);
}
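    /// Records a dispatch's reads and writes without checking for hazards.
    /// Use after a barrier, when the dispatch opens a fresh epoch.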
pub fn add_dispatch(&mut self, reads: &[&MlxBuffer], writes: &[&MlxBuffer]) {
for r in reads {
self.ranges
.push(BufferRange::from_buffer(r, MemRangeRole::Src));
}
for w in writes {
self.ranges
.push(BufferRange::from_buffer(w, MemRangeRole::Dst));
}
}
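    /// Returns `true` if the dispatch can be encoded without a barrier, i.e.
    /// none of its reads or writes conflict with an already-recorded range.
    /// Nothing is recorded, and a dispatch's own ranges are not checked
    /// against each other, so a dispatch may freely read and write the same
    /// buffer.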
pub fn check_dispatch(&mut self, reads: &[&MlxBuffer], writes: &[&MlxBuffer]) -> bool {
self.checks += 1;
for r in reads {
let candidate = BufferRange::from_buffer(r, MemRangeRole::Src);
for existing in &self.ranges {
if candidate.conflicts_with(existing) {
self.barriers_forced += 1;
return false;
}
}
}
for w in writes {
let candidate = BufferRange::from_buffer(w, MemRangeRole::Dst);
for existing in &self.ranges {
if candidate.conflicts_with(existing) {
self.barriers_forced += 1;
return false;
}
}
}
true
}
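    /// Checks the dispatch and, only if it is hazard-free, records its
    /// ranges. On `false`, the caller must encode a barrier, `reset`, and
    /// then `add_dispatch` the same dispatch itself.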
    pub fn check_and_record(&mut self, reads: &[&MlxBuffer], writes: &[&MlxBuffer]) -> bool {
let ok = self.check_dispatch(reads, writes);
if ok {
self.add_dispatch(reads, writes);
}
ok
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{DType, MlxDevice};
fn dev() -> MlxDevice {
MlxDevice::new().expect("MlxDevice::new failed")
}
#[test]
fn read_read_same_buffer_no_conflict() {
let d = dev();
let a = d.alloc_buffer(64, DType::F32, vec![16]).unwrap();
let mut mr = MemRanges::new();
let ok1 = mr.check_and_record(&[&a], &[]);
assert!(ok1, "first dispatch always ok");
let ok2 = mr.check_and_record(&[&a], &[]);
assert!(ok2, "RAR same-buffer must not conflict");
assert_eq!(mr.barriers_forced(), 0);
}
#[test]
fn raw_same_buffer_conflicts() {
let d = dev();
let a = d.alloc_buffer(64, DType::F32, vec![16]).unwrap();
let mut mr = MemRanges::new();
assert!(mr.check_and_record(&[], &[&a]));
let ok = mr.check_and_record(&[&a], &[]);
assert!(!ok, "RAW same-buffer must force barrier");
assert_eq!(mr.barriers_forced(), 1);
}
#[test]
fn war_same_buffer_conflicts() {
let d = dev();
let a = d.alloc_buffer(64, DType::F32, vec![16]).unwrap();
let mut mr = MemRanges::new();
assert!(mr.check_and_record(&[&a], &[]));
let ok = mr.check_and_record(&[], &[&a]);
assert!(!ok, "WAR same-buffer must force barrier");
assert_eq!(mr.barriers_forced(), 1);
}
#[test]
fn waw_same_buffer_conflicts() {
let d = dev();
let a = d.alloc_buffer(64, DType::F32, vec![16]).unwrap();
let mut mr = MemRanges::new();
assert!(mr.check_and_record(&[], &[&a]));
let ok = mr.check_and_record(&[], &[&a]);
assert!(!ok, "WAW same-buffer must force barrier");
assert_eq!(mr.barriers_forced(), 1);
}
#[test]
fn different_buffers_never_conflict() {
let d = dev();
let a = d.alloc_buffer(64, DType::F32, vec![16]).unwrap();
let b = d.alloc_buffer(64, DType::F32, vec![16]).unwrap();
let c = d.alloc_buffer(64, DType::F32, vec![16]).unwrap();
let mut mr = MemRanges::new();
assert!(mr.check_and_record(&[], &[&a]));
assert!(mr.check_and_record(&[&b], &[&b]));
assert!(mr.check_and_record(&[&c], &[]));
assert_eq!(mr.barriers_forced(), 0);
}
#[test]
fn reset_clears_state() {
let d = dev();
let a = d.alloc_buffer(64, DType::F32, vec![16]).unwrap();
let mut mr = MemRanges::new();
assert!(mr.check_and_record(&[], &[&a]));
assert!(!mr.check_and_record(&[&a], &[]));
mr.reset();
assert!(mr.check_and_record(&[&a], &[]));
assert!(mr.check_and_record(&[&a], &[]));
assert_eq!(mr.barriers_forced(), 1);
}
#[test]
fn slices_of_same_parent_conservative() {
let d = dev();
let parent = d.alloc_buffer(1024, DType::F32, vec![256]).unwrap();
let lo = parent.slice_view(0, 128);
let hi = parent.slice_view(512, 128);
let mut mr = MemRanges::new();
assert!(mr.check_and_record(&[], &[&lo]));
let ok = mr.check_and_record(&[], &[&hi]);
assert!(!ok, "slice WAW currently conservative — see docstring");
}
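    // Models the chain: write x; read x / write y; read y. Each read of a
    // freshly written buffer forces one barrier, two in total.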
#[test]
fn sequential_pattern_two_barriers() {
let d = dev();
let x = d.alloc_buffer(64, DType::F32, vec![16]).unwrap();
let y = d.alloc_buffer(64, DType::F32, vec![16]).unwrap();
let mut mr = MemRanges::new();
assert!(mr.check_and_record(&[], &[&x]));
assert!(!mr.check_dispatch(&[&x], &[]));
mr.reset();
mr.add_dispatch(&[&x], &[]);
assert!(mr.check_and_record(&[], &[&y]));
assert!(!mr.check_dispatch(&[&y], &[]));
mr.reset();
mr.add_dispatch(&[&y], &[]);
assert_eq!(mr.barriers_forced(), 2);
}
#[test]
fn conflict_is_symmetric() {
let d = dev();
let a = d.alloc_buffer(64, DType::F32, vec![16]).unwrap();
let r_src = BufferRange::from_buffer(&a, MemRangeRole::Src);
let r_dst = BufferRange::from_buffer(&a, MemRangeRole::Dst);
assert!(r_src.conflicts_with(&r_dst));
assert!(r_dst.conflicts_with(&r_src));
}
}