use crate::address::Address;
use crate::error::ProgramError;
/// Maximum number of segment borrows a `SegmentBorrowRegistry` tracks at once.
pub const MAX_SEGMENT_BORROWS: usize = 16;

// The registry keeps its entry count in a `u8` and converts with `as u8`;
// fail the build instead of silently truncating if the capacity is ever
// raised beyond what a `u8` can count.
const _: () = assert!(MAX_SEGMENT_BORROWS <= u8::MAX as usize);
/// How a registered segment borrow accesses its byte range.
///
/// The registry accepts overlapping ranges on the same address only when
/// both borrows are `Read`; any overlap involving a `Write` is rejected.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum AccessKind {
    /// Shared access; compatible with other overlapping reads.
    Read = 0,
    /// Exclusive access; incompatible with any overlapping borrow.
    Write = 1,
}
/// Cheap 64-bit prefilter for address comparisons: the first eight bytes of
/// the address interpreted as a little-endian `u64`.
///
/// Equal addresses always yield equal fingerprints. Different addresses can
/// share a fingerprint, so a fingerprint match must still be confirmed with a
/// full comparison.
#[inline(always)]
fn address_fingerprint(address: &Address) -> u64 {
    let raw = address.as_array();
    let mut prefix = [0u8; 8];
    prefix.copy_from_slice(&raw[..8]);
    u64::from_le_bytes(prefix)
}
/// Byte-wise equality of the two addresses' underlying arrays.
#[inline(always)]
fn address_eq(a: &Address, b: &Address) -> bool {
    let (lhs, rhs) = (a.as_array(), b.as_array());
    lhs[..] == rhs[..]
}
/// Descriptor of one borrowed range `[offset, offset + size)` on an address.
#[derive(Debug, Clone, Copy)]
pub struct SegmentBorrow {
    /// Fingerprint of `key` used as a cheap early-out before the full address
    /// comparison. Expected to equal `address_fingerprint(&key)`.
    pub key_fp: u64,
    /// Address whose data the range refers to.
    pub key: Address,
    /// Start of the borrowed range.
    pub offset: u32,
    /// Length of the borrowed range; the range is half-open.
    pub size: u32,
    /// Whether the range is borrowed for shared reading or exclusive writing.
    pub kind: AccessKind,
}
/// Returns `true` when the half-open ranges `[a_off, a_off + a_size)` and
/// `[b_off, b_off + b_size)` share at least one position.
///
/// End points are computed in `u64` so that ranges reaching past `u32::MAX`
/// neither overflow (a panic in debug builds) nor wrap around (which in
/// release builds would make genuinely overlapping ranges look disjoint).
/// A zero-length range lying strictly inside another range is reported as
/// overlapping; this is conservative for a conflict check.
#[inline(always)]
const fn ranges_overlap(a_off: u32, a_size: u32, b_off: u32, b_size: u32) -> bool {
    // `as u64` is a lossless widening here; `From` is not usable in `const fn`.
    let a_start = a_off as u64;
    let b_start = b_off as u64;
    let a_end = a_start + a_size as u64;
    let b_end = b_start + b_size as u64;
    a_start < b_end && b_start < a_end
}
/// Fixed-capacity table of active segment borrows.
///
/// Stores at most `MAX_SEGMENT_BORROWS` entries inline, with no heap
/// allocation. Only `entries[..len]` are live.
pub struct SegmentBorrowRegistry {
    /// Backing storage; slots at index `len` and beyond are not live.
    entries: [SegmentBorrow; MAX_SEGMENT_BORROWS],
    /// Number of live entries at the front of `entries`.
    len: u8,
}
impl SegmentBorrowRegistry {
#[inline(always)]
pub const fn new() -> Self {
const EMPTY: SegmentBorrow = SegmentBorrow {
key_fp: 0,
key: Address::new([0u8; 32]),
offset: 0,
size: 0,
kind: AccessKind::Read,
};
Self {
entries: [EMPTY; MAX_SEGMENT_BORROWS],
len: 0,
}
}
#[inline(always)]
pub const fn len(&self) -> usize {
self.len as usize
}
#[inline(always)]
pub const fn is_empty(&self) -> bool {
self.len == 0
}
#[inline(always)]
pub fn register_leased_read(
&mut self,
key: &Address,
offset: u32,
size: u32,
) -> Result<SegmentBorrow, ProgramError> {
let borrow = SegmentBorrow {
key_fp: address_fingerprint(key),
key: *key,
offset,
size,
kind: AccessKind::Read,
};
self.register(borrow)?;
Ok(borrow)
}
#[inline(always)]
pub fn register_leased_write(
&mut self,
key: &Address,
offset: u32,
size: u32,
) -> Result<SegmentBorrow, ProgramError> {
let borrow = SegmentBorrow {
key_fp: address_fingerprint(key),
key: *key,
offset,
size,
kind: AccessKind::Write,
};
self.register(borrow)?;
Ok(borrow)
}
#[inline(always)]
pub fn register(&mut self, new: SegmentBorrow) -> Result<(), ProgramError> {
let len = self.len as usize;
if len >= MAX_SEGMENT_BORROWS {
return Err(ProgramError::AccountBorrowFailed);
}
let mut i = 0;
while i < len {
let existing = &self.entries[i];
if existing.key_fp == new.key_fp
&& address_eq(&existing.key, &new.key)
&& ranges_overlap(existing.offset, existing.size, new.offset, new.size)
{
match (existing.kind, new.kind) {
(AccessKind::Read, AccessKind::Read) => {}
_ => return Err(ProgramError::AccountBorrowFailed),
}
}
i += 1;
}
self.entries[len] = new;
self.len = (len + 1) as u8;
Ok(())
}
#[inline(always)]
pub fn register_read(
&mut self,
key: &Address,
offset: u32,
size: u32,
) -> Result<(), ProgramError> {
self.register(SegmentBorrow {
key_fp: address_fingerprint(key),
key: *key,
offset,
size,
kind: AccessKind::Read,
})
}
#[inline(always)]
pub fn register_write(
&mut self,
key: &Address,
offset: u32,
size: u32,
) -> Result<(), ProgramError> {
self.register(SegmentBorrow {
key_fp: address_fingerprint(key),
key: *key,
offset,
size,
kind: AccessKind::Write,
})
}
#[inline(always)]
pub fn release(&mut self, borrow: &SegmentBorrow) -> bool {
let len = self.len as usize;
let mut i = 0;
while i < len {
let existing = &self.entries[i];
if existing.key_fp == borrow.key_fp
&& address_eq(&existing.key, &borrow.key)
&& existing.offset == borrow.offset
&& existing.size == borrow.size
&& existing.kind == borrow.kind
{
let new_len = len - 1;
self.len = new_len as u8;
if i < new_len {
self.entries[i] = self.entries[new_len];
}
return true;
}
i += 1;
}
false
}
#[inline(always)]
pub fn clear(&mut self) {
self.len = 0;
}
#[inline(always)]
pub fn would_conflict(&self, proposed: &SegmentBorrow) -> bool {
let len = self.len as usize;
let mut i = 0;
while i < len {
let existing = &self.entries[i];
if existing.key_fp == proposed.key_fp
&& address_eq(&existing.key, &proposed.key)
&& ranges_overlap(existing.offset, existing.size, proposed.offset, proposed.size)
{
match (existing.kind, proposed.kind) {
(AccessKind::Read, AccessKind::Read) => {}
_ => return true,
}
}
i += 1;
}
false
}
#[inline(always)]
pub fn register_guard(
&mut self,
borrow: SegmentBorrow,
) -> Result<SegmentBorrowGuard<'_>, ProgramError> {
self.register(borrow)?;
Ok(SegmentBorrowGuard {
registry: self,
borrow,
})
}
#[inline(always)]
pub fn register_guard_read(
&mut self,
key: &Address,
offset: u32,
size: u32,
) -> Result<SegmentBorrowGuard<'_>, ProgramError> {
let borrow = SegmentBorrow {
key_fp: address_fingerprint(key),
key: *key,
offset,
size,
kind: AccessKind::Read,
};
self.register_guard(borrow)
}
#[inline(always)]
pub fn register_guard_write(
&mut self,
key: &Address,
offset: u32,
size: u32,
) -> Result<SegmentBorrowGuard<'_>, ProgramError> {
let borrow = SegmentBorrow {
key_fp: address_fingerprint(key),
key: *key,
offset,
size,
kind: AccessKind::Write,
};
self.register_guard(borrow)
}
#[inline]
pub fn for_each<F: FnMut(&SegmentBorrow)>(&self, mut f: F) {
let len = self.len as usize;
let mut i = 0;
while i < len {
f(&self.entries[i]);
i += 1;
}
}
#[inline]
pub fn find_exact(
&self,
key: &Address,
offset: u32,
size: u32,
kind: AccessKind,
) -> Option<&SegmentBorrow> {
let fp = address_fingerprint(key);
let len = self.len as usize;
let mut i = 0;
while i < len {
let e = &self.entries[i];
if e.key_fp == fp
&& address_eq(&e.key, key)
&& e.offset == offset
&& e.size == size
&& e.kind as u8 == kind as u8
{
return Some(e);
}
i += 1;
}
None
}
}
/// RAII handle for a borrow registered through one of the
/// `SegmentBorrowRegistry::register_guard*` methods.
///
/// Dropping the guard releases the stored borrow from the registry. The guard
/// holds the registry by `&mut`, so the registry cannot be used, and no second
/// guard can be created from it, while this guard is alive.
#[must_use = "dropping the guard immediately releases the borrow it represents"]
pub struct SegmentBorrowGuard<'a> {
    /// Registry the borrow was registered in; released from here on drop.
    registry: &'a mut SegmentBorrowRegistry,
    /// The registered borrow this guard owns.
    borrow: SegmentBorrow,
}
impl<'a> SegmentBorrowGuard<'a> {
#[inline(always)]
pub fn kind(&self) -> AccessKind {
self.borrow.kind
}
#[inline(always)]
pub fn offset(&self) -> u32 {
self.borrow.offset
}
#[inline(always)]
pub fn size(&self) -> u32 {
self.borrow.size
}
}
impl Drop for SegmentBorrowGuard<'_> {
    /// Releases the guarded borrow from its registry.
    fn drop(&mut self) {
        // `release` reports whether a matching entry was found. A guard has no
        // useful way to react to a miss while being dropped, so the result is
        // deliberately discarded.
        let _ = self.registry.release(&self.borrow);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Address;

    /// Builds an address whose 32 bytes are all `seed`.
    fn test_addr(seed: u8) -> Address {
        Address::new([seed; 32])
    }

    #[test]
    fn read_read_same_range_allowed() {
        let mut reg = SegmentBorrowRegistry::new();
        let key = test_addr(1);
        assert!(reg.register_read(&key, 0, 8).is_ok());
        assert!(reg.register_read(&key, 0, 8).is_ok());
        assert_eq!(reg.len(), 2);
    }

    #[test]
    fn read_write_same_range_rejected() {
        let mut reg = SegmentBorrowRegistry::new();
        let key = test_addr(1);
        assert!(reg.register_read(&key, 0, 8).is_ok());
        assert!(reg.register_write(&key, 0, 8).is_err());
    }

    #[test]
    fn write_write_same_range_rejected() {
        let mut reg = SegmentBorrowRegistry::new();
        let key = test_addr(1);
        assert!(reg.register_write(&key, 0, 8).is_ok());
        assert!(reg.register_write(&key, 0, 8).is_err());
    }

    #[test]
    fn write_read_same_range_rejected() {
        let mut reg = SegmentBorrowRegistry::new();
        let key = test_addr(1);
        assert!(reg.register_write(&key, 0, 8).is_ok());
        assert!(reg.register_read(&key, 0, 8).is_err());
    }

    #[test]
    fn non_overlapping_write_write_allowed() {
        let mut reg = SegmentBorrowRegistry::new();
        let key = test_addr(1);
        assert!(reg.register_write(&key, 0, 8).is_ok());
        assert!(reg.register_write(&key, 8, 32).is_ok());
    }

    #[test]
    fn partially_overlapping_rejected() {
        let mut reg = SegmentBorrowRegistry::new();
        let key = test_addr(1);
        assert!(reg.register_write(&key, 0, 16).is_ok());
        assert!(reg.register_write(&key, 8, 16).is_err());
    }

    #[test]
    fn different_accounts_always_allowed() {
        let mut reg = SegmentBorrowRegistry::new();
        assert!(reg.register_write(&test_addr(1), 0, 8).is_ok());
        assert!(reg.register_write(&test_addr(2), 0, 8).is_ok());
    }

    #[test]
    fn release_then_reacquire() {
        let mut reg = SegmentBorrowRegistry::new();
        let key = test_addr(1);
        let borrow = SegmentBorrow {
            key_fp: address_fingerprint(&key),
            key,
            offset: 0,
            size: 8,
            kind: AccessKind::Write,
        };
        assert!(reg.register(borrow).is_ok());
        assert!(reg.register_write(&key, 0, 8).is_err());
        assert!(reg.release(&borrow));
        assert!(reg.register_write(&key, 0, 8).is_ok());
    }

    #[test]
    fn capacity_limit() {
        let mut reg = SegmentBorrowRegistry::new();
        for i in 0..MAX_SEGMENT_BORROWS {
            assert!(reg.register_read(&test_addr(1), i as u32 * 8, 8).is_ok());
        }
        // Non-overlapping, so the only reason to fail is capacity.
        assert!(reg.register_read(&test_addr(1), 256, 8).is_err());
    }

    #[test]
    fn would_conflict_does_not_mutate() {
        let mut reg = SegmentBorrowRegistry::new();
        let key = test_addr(1);
        assert!(reg.register_write(&key, 0, 8).is_ok());
        let proposed = SegmentBorrow {
            key_fp: address_fingerprint(&key),
            key,
            offset: 0,
            size: 8,
            kind: AccessKind::Write,
        };
        assert!(reg.would_conflict(&proposed));
        assert_eq!(reg.len(), 1);
    }

    #[test]
    fn adjacent_ranges_no_conflict() {
        let mut reg = SegmentBorrowRegistry::new();
        let key = test_addr(1);
        assert!(reg.register_write(&key, 0, 8).is_ok());
        assert!(reg.register_write(&key, 8, 8).is_ok());
    }

    #[test]
    fn guard_auto_releases_write_on_drop() {
        let mut reg = SegmentBorrowRegistry::new();
        let key = test_addr(1);
        {
            let _guard = reg.register_guard_write(&key, 0, 8).unwrap();
        }
        assert_eq!(reg.len(), 0);
        assert!(reg.register_write(&key, 0, 8).is_ok());
    }

    #[test]
    fn guard_auto_releases_read_on_drop() {
        let mut reg = SegmentBorrowRegistry::new();
        let key = test_addr(1);
        {
            let _guard = reg.register_guard_read(&key, 0, 8).unwrap();
        }
        assert_eq!(reg.len(), 0);
        assert!(reg.register_write(&key, 0, 8).is_ok());
    }

    #[test]
    fn sequential_guards_reuse_slot() {
        let mut reg = SegmentBorrowRegistry::new();
        let key = test_addr(1);
        for _ in 0..4 {
            let _guard = reg.register_guard_write(&key, 0, 8).unwrap();
        }
        assert_eq!(reg.len(), 0);
    }

    #[test]
    fn guard_accessors() {
        let mut reg = SegmentBorrowRegistry::new();
        let key = test_addr(1);
        let guard = reg.register_guard_write(&key, 16, 32).unwrap();
        assert_eq!(guard.kind(), AccessKind::Write);
        assert_eq!(guard.offset(), 16);
        assert_eq!(guard.size(), 32);
    }

    #[test]
    fn guard_then_manual_register_ok() {
        let mut reg = SegmentBorrowRegistry::new();
        let key = test_addr(1);
        {
            let _guard = reg.register_guard_write(&key, 0, 8).unwrap();
        }
        assert!(reg.register_read(&key, 0, 8).is_ok());
        assert_eq!(reg.len(), 1);
    }

    #[test]
    fn new_registry_is_empty() {
        let reg = SegmentBorrowRegistry::new();
        assert!(reg.is_empty());
        assert_eq!(reg.len(), 0);
    }

    #[test]
    fn find_exact_matches_all_fields() {
        let mut reg = SegmentBorrowRegistry::new();
        let key = test_addr(3);
        assert!(reg.register_read(&key, 4, 12).is_ok());
        assert!(reg.find_exact(&key, 4, 12, AccessKind::Read).is_some());
        assert!(reg.find_exact(&key, 4, 12, AccessKind::Write).is_none());
        assert!(reg.find_exact(&key, 4, 8, AccessKind::Read).is_none());
        assert!(reg.find_exact(&key, 0, 12, AccessKind::Read).is_none());
        assert!(reg.find_exact(&test_addr(4), 4, 12, AccessKind::Read).is_none());
    }

    #[test]
    fn release_unknown_borrow_returns_false() {
        let mut reg = SegmentBorrowRegistry::new();
        let key = test_addr(1);
        let borrow = reg.register_leased_write(&key, 0, 8).unwrap();
        let other = SegmentBorrow { offset: 8, ..borrow };
        assert!(!reg.release(&other));
        assert_eq!(reg.len(), 1);
        assert!(reg.release(&borrow));
        assert!(reg.is_empty());
        assert!(!reg.release(&borrow));
    }

    #[test]
    fn release_middle_entry_keeps_others() {
        let mut reg = SegmentBorrowRegistry::new();
        let key = test_addr(1);
        let first = reg.register_leased_write(&key, 0, 8).unwrap();
        let middle = reg.register_leased_write(&key, 8, 8).unwrap();
        let last = reg.register_leased_write(&key, 16, 8).unwrap();
        assert!(reg.release(&middle));
        assert_eq!(reg.len(), 2);
        assert!(reg.find_exact(&key, 0, 8, AccessKind::Write).is_some());
        assert!(reg.find_exact(&key, 16, 8, AccessKind::Write).is_some());
        assert!(reg.find_exact(&key, 8, 8, AccessKind::Write).is_none());
        assert!(reg.release(&first));
        assert!(reg.release(&last));
        assert!(reg.is_empty());
    }

    #[test]
    fn clear_drops_every_borrow() {
        let mut reg = SegmentBorrowRegistry::new();
        let key = test_addr(1);
        assert!(reg.register_write(&key, 0, 8).is_ok());
        assert!(reg.register_read(&key, 16, 8).is_ok());
        reg.clear();
        assert!(reg.is_empty());
        assert!(reg.register_write(&key, 0, 24).is_ok());
    }

    #[test]
    fn for_each_visits_live_entries_only() {
        let mut reg = SegmentBorrowRegistry::new();
        let key = test_addr(1);
        assert!(reg.register_read(&key, 0, 8).is_ok());
        assert!(reg.register_read(&key, 0, 8).is_ok());
        let removed = reg.register_leased_write(&key, 64, 100).unwrap();
        assert!(reg.release(&removed));
        let mut total = 0u32;
        let mut count = 0usize;
        reg.for_each(|b| {
            total += b.size;
            count += 1;
        });
        assert_eq!(count, 2);
        assert_eq!(total, 16);
    }
}