use ruvix_types::{KernelError, MsgPriority, RegionHandle, RegionPolicy};
use crate::Result;
/// A fixed-layout (`repr(C)`) descriptor naming `length` bytes starting at `offset`
/// inside `region`.
///
/// `to_bytes`/`from_bytes` reinterpret this struct as raw bytes, so its field order is
/// part of its wire format: do not reorder fields. The tests pin the size at 24 bytes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(C)]
pub struct MessageDescriptor {
/// Region holding the referenced bytes; a null handle makes the descriptor invalid.
pub region: RegionHandle,
/// Byte offset of the span from the start of `region`.
pub offset: u64,
/// Length of the span in bytes; zero makes the descriptor invalid.
pub length: u32,
/// Reserved word, zeroed by `new`. It fills what would otherwise be implicit trailing
/// padding after `length` (the struct is at least `u64`-aligned).
_padding: u32,
}
impl MessageDescriptor {
    /// Size in bytes of the raw representation produced by [`Self::to_bytes`].
    pub const SIZE: usize = core::mem::size_of::<Self>();

    /// Creates a descriptor naming `length` bytes at `offset` within `region`.
    ///
    /// The reserved padding word is always zeroed.
    #[inline]
    #[must_use]
    pub const fn new(region: RegionHandle, offset: u64, length: u32) -> Self {
        Self {
            region,
            offset,
            length,
            _padding: 0,
        }
    }

    /// Returns `true` if the descriptor names a non-null region and a non-empty span.
    ///
    /// This does not check the span against the region's size.
    #[inline]
    pub fn is_valid(&self) -> bool {
        !self.region.is_null() && self.length > 0
    }

    /// Returns the descriptor's in-memory (`repr(C)`, native-endian) bytes.
    ///
    /// The byte order follows the host, so the encoding is only meaningful between
    /// parties sharing the same architecture.
    #[inline]
    #[must_use]
    pub fn to_bytes(&self) -> [u8; Self::SIZE] {
        // Compile-time proof that the struct has no inter-field or trailing padding:
        // its size equals the sum of its field sizes, so no transmuted byte is
        // uninitialized padding.
        const _: () = assert!(
            core::mem::size_of::<MessageDescriptor>()
                == core::mem::size_of::<RegionHandle>()
                    + core::mem::size_of::<u64>()
                    + core::mem::size_of::<u32>()
                    + core::mem::size_of::<u32>()
        );
        // SAFETY: source and destination have the same size (`Self::SIZE`). The
        // assertion above rules out padding between or after fields, and `[u8; N]`
        // accepts any bit pattern. NOTE(review): this assumes `RegionHandle` itself has
        // no internal padding; confirm against its definition.
        unsafe { core::mem::transmute(*self) }
    }

    /// Decodes a descriptor from the first [`Self::SIZE`] bytes of `bytes`.
    ///
    /// Returns `None` if `bytes` is shorter than [`Self::SIZE`]; trailing bytes are
    /// ignored. The reserved padding word is reset to zero, so a decoded descriptor
    /// compares equal to one built with [`Self::new`] from the same fields.
    #[inline]
    #[must_use]
    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
        let raw = bytes.get(..Self::SIZE)?;
        let mut arr = [0u8; Self::SIZE];
        arr.copy_from_slice(raw);
        // SAFETY: `arr` is exactly `Self::SIZE` bytes. `offset`, `length` and `_padding`
        // are plain integers, valid for any bit pattern. NOTE(review): this also
        // requires `RegionHandle` to be valid for any bit pattern (no niches/enums);
        // confirm against its definition.
        let mut desc: Self = unsafe { core::mem::transmute(arr) };
        // Untrusted input may carry garbage in the reserved word. Normalize it so the
        // derived `PartialEq`/`Eq` only reflect the meaningful fields.
        desc._padding = 0;
        Some(desc)
    }
}
/// Validates [`MessageDescriptor`]s against a region's policy and size.
///
/// Each flag enables descriptors that point into regions with the corresponding
/// [`RegionPolicy`]. `Slab` regions are never accepted.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DescriptorValidator {
    /// Accept descriptors into `RegionPolicy::Immutable` regions.
    allow_immutable: bool,
    /// Accept descriptors into `RegionPolicy::AppendOnly` regions.
    allow_append_only: bool,
}
impl DescriptorValidator {
    /// Creates a validator that accepts descriptors into both `Immutable` and
    /// `AppendOnly` regions.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            allow_immutable: true,
            allow_append_only: true,
        }
    }

    /// Creates a stricter validator that accepts descriptors into `Immutable` regions
    /// only.
    #[must_use]
    pub const fn immutable_only() -> Self {
        Self {
            allow_immutable: true,
            allow_append_only: false,
        }
    }

    /// Checks whether descriptors may reference a region governed by `policy`.
    ///
    /// `Immutable` and `AppendOnly` regions are accepted when enabled on this
    /// validator. `Slab` regions, and any other policy, are always rejected.
    ///
    /// # Errors
    ///
    /// Returns [`KernelError::InvalidDescriptorRegion`] when the policy is not accepted.
    pub fn validate_policy(&self, policy: &RegionPolicy) -> Result<()> {
        let allowed = match policy {
            RegionPolicy::Immutable => self.allow_immutable,
            RegionPolicy::AppendOnly { .. } => self.allow_append_only,
            // `Slab` (and any other policy) is never a valid descriptor target.
            _ => false,
        };
        if allowed {
            Ok(())
        } else {
            Err(KernelError::InvalidDescriptorRegion)
        }
    }

    /// Checks that the span `[offset, offset + length)` fits within a region of
    /// `region_size` bytes.
    ///
    /// # Errors
    ///
    /// Returns [`KernelError::InvalidArgument`] if `offset + length` overflows `u64`
    /// or exceeds `region_size`.
    pub fn validate_bounds(&self, descriptor: &MessageDescriptor, region_size: usize) -> Result<()> {
        let end = descriptor
            .offset
            .checked_add(u64::from(descriptor.length))
            .ok_or(KernelError::InvalidArgument)?;
        // `usize` is at most 64 bits on every current Rust target, so this widening
        // cast is lossless.
        if end > region_size as u64 {
            return Err(KernelError::InvalidArgument);
        }
        Ok(())
    }

    /// Runs every check, in this order:
    /// 1. descriptor well-formedness;
    /// 2. region policy;
    /// 3. bounds.
    ///
    /// The first failing check determines the error.
    ///
    /// # Errors
    ///
    /// - [`KernelError::InvalidArgument`] if the descriptor has a null region or zero
    ///   length, or if its span falls outside `region_size`.
    /// - [`KernelError::InvalidDescriptorRegion`] if `policy` is not accepted.
    pub fn validate(
        &self,
        descriptor: &MessageDescriptor,
        policy: &RegionPolicy,
        region_size: usize,
    ) -> Result<()> {
        if !descriptor.is_valid() {
            return Err(KernelError::InvalidArgument);
        }
        self.validate_policy(policy)?;
        self.validate_bounds(descriptor, region_size)?;
        Ok(())
    }
}
impl Default for DescriptorValidator {
fn default() -> Self {
Self::new()
}
}
/// A [`MessageDescriptor`] paired with a [`MsgPriority`].
///
/// NOTE(review): presumably the priority orders delivery when the descriptor is
/// queued; confirm against the queue implementation.
#[derive(Debug, Clone, Copy)]
pub struct PrioritizedDescriptor {
/// The wrapped descriptor.
pub descriptor: MessageDescriptor,
/// Priority attached to `descriptor`.
pub priority: MsgPriority,
}
impl PrioritizedDescriptor {
#[inline]
pub const fn new(descriptor: MessageDescriptor, priority: MsgPriority) -> Self {
Self {
descriptor,
priority,
}
}
#[inline]
pub const fn with_normal_priority(descriptor: MessageDescriptor) -> Self {
Self {
descriptor,
priority: MsgPriority::Normal,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use ruvix_types::Handle;

    /// A non-null region handle used throughout the tests.
    fn test_region() -> RegionHandle {
        RegionHandle(Handle::new(1, 0))
    }

    #[test]
    fn test_descriptor_size() {
        assert_eq!(MessageDescriptor::SIZE, 24);
    }

    #[test]
    fn test_descriptor_roundtrip() {
        let desc = MessageDescriptor::new(test_region(), 100, 256);
        let bytes = desc.to_bytes();
        let recovered = MessageDescriptor::from_bytes(&bytes).unwrap();
        assert_eq!(desc.region, recovered.region);
        assert_eq!(desc.offset, recovered.offset);
        assert_eq!(desc.length, recovered.length);
        assert_eq!(desc, recovered);
    }

    #[test]
    fn test_descriptor_from_bytes_too_short() {
        let bytes = MessageDescriptor::new(test_region(), 100, 256).to_bytes();
        assert!(MessageDescriptor::from_bytes(&bytes[..MessageDescriptor::SIZE - 1]).is_none());
        assert!(MessageDescriptor::from_bytes(&[0u8; 0]).is_none());
    }

    #[test]
    fn test_descriptor_from_bytes_ignores_trailing_bytes() {
        let desc = MessageDescriptor::new(test_region(), 8, 16);
        let mut buf = [0xAAu8; MessageDescriptor::SIZE + 4];
        buf[..MessageDescriptor::SIZE].copy_from_slice(&desc.to_bytes());
        let recovered = MessageDescriptor::from_bytes(&buf).unwrap();
        assert_eq!(recovered, desc);
    }

    #[test]
    fn test_descriptor_validation_null() {
        let desc = MessageDescriptor::new(RegionHandle::null(), 0, 100);
        assert!(!desc.is_valid());
    }

    #[test]
    fn test_descriptor_validation_zero_length() {
        let desc = MessageDescriptor::new(test_region(), 0, 0);
        assert!(!desc.is_valid());
    }

    #[test]
    fn test_validator_immutable_ok() {
        let validator = DescriptorValidator::new();
        let result = validator.validate_policy(&RegionPolicy::Immutable);
        assert!(result.is_ok());
    }

    #[test]
    fn test_validator_append_only_ok() {
        let validator = DescriptorValidator::new();
        let result = validator.validate_policy(&RegionPolicy::AppendOnly { max_size: 1024 });
        assert!(result.is_ok());
    }

    #[test]
    fn test_validator_slab_rejected() {
        let validator = DescriptorValidator::new();
        let result = validator.validate_policy(&RegionPolicy::Slab {
            slot_size: 64,
            slot_count: 16,
        });
        assert!(matches!(result, Err(KernelError::InvalidDescriptorRegion)));
    }

    #[test]
    fn test_validator_bounds() {
        let validator = DescriptorValidator::new();
        let desc = MessageDescriptor::new(test_region(), 100, 256);
        assert!(validator.validate_bounds(&desc, 500).is_ok());
        assert!(validator.validate_bounds(&desc, 356).is_ok());
        assert!(validator.validate_bounds(&desc, 355).is_err());
    }

    #[test]
    fn test_validator_bounds_overflow() {
        let validator = DescriptorValidator::new();
        let desc = MessageDescriptor::new(test_region(), u64::MAX - 10, 100);
        assert!(validator.validate_bounds(&desc, 1000).is_err());
    }

    #[test]
    fn test_full_validation() {
        let validator = DescriptorValidator::new();
        let desc = MessageDescriptor::new(test_region(), 100, 256);
        assert!(validator
            .validate(&desc, &RegionPolicy::Immutable, 500)
            .is_ok());
        assert!(validator
            .validate(
                &desc,
                &RegionPolicy::Slab {
                    slot_size: 64,
                    slot_count: 16
                },
                500
            )
            .is_err());
    }

    #[test]
    fn test_full_validation_rejects_malformed_descriptor() {
        let validator = DescriptorValidator::new();
        let empty = MessageDescriptor::new(test_region(), 0, 0);
        assert!(matches!(
            validator.validate(&empty, &RegionPolicy::Immutable, 500),
            Err(KernelError::InvalidArgument)
        ));
        let null = MessageDescriptor::new(RegionHandle::null(), 0, 100);
        assert!(matches!(
            validator.validate(&null, &RegionPolicy::Immutable, 500),
            Err(KernelError::InvalidArgument)
        ));
    }

    #[test]
    fn test_full_validation_out_of_bounds() {
        let validator = DescriptorValidator::new();
        let desc = MessageDescriptor::new(test_region(), 100, 256);
        assert!(matches!(
            validator.validate(&desc, &RegionPolicy::Immutable, 355),
            Err(KernelError::InvalidArgument)
        ));
        let at_end = MessageDescriptor::new(test_region(), 500, 1);
        assert!(validator.validate_bounds(&at_end, 500).is_err());
    }

    #[test]
    fn test_immutable_only_validator() {
        let validator = DescriptorValidator::immutable_only();
        assert!(validator.validate_policy(&RegionPolicy::Immutable).is_ok());
        assert!(validator
            .validate_policy(&RegionPolicy::AppendOnly { max_size: 1024 })
            .is_err());
        assert!(matches!(
            validator.validate_policy(&RegionPolicy::Slab {
                slot_size: 64,
                slot_count: 16,
            }),
            Err(KernelError::InvalidDescriptorRegion)
        ));
    }

    #[test]
    fn test_prioritized_descriptor_normal() {
        let desc = MessageDescriptor::new(test_region(), 0, 1);
        let prioritized = PrioritizedDescriptor::with_normal_priority(desc);
        assert_eq!(prioritized.descriptor, desc);
        assert!(matches!(prioritized.priority, MsgPriority::Normal));
    }
}