use crate::control::ControlBlock;
use crate::error::{Result, RingKernelError};
use bytemuck::{Pod, Zeroable};
/// Number of bytes of a [`ControlBlock`] available for embedded kernel state.
///
/// This is the length of the control block's reserved byte area that
/// `ControlBlockStateHelper` reads and writes.
pub const CONTROL_BLOCK_STATE_SIZE: usize = 24;

/// Tag stored in [`StateDescriptor::magic`] to mark the reserved area as a descriptor.
///
/// Numerically `0x5441_5453`: the ASCII bytes `STAT` interpreted as a little-endian `u32`.
pub const STATE_DESCRIPTOR_MAGIC: u32 = u32::from_le_bytes(*b"STAT");
/// Plain-old-data state that can be copied byte-for-byte into a control block.
///
/// The `Pod` bound is what lets the raw bytes be reinterpreted without `unsafe`.
/// Whether a type actually fits in [`CONTROL_BLOCK_STATE_SIZE`] bytes is checked
/// separately. There is a compile-time check via `EmbeddedStateSize::SIZE_CHECK`, which
/// only fires where that constant is referenced. There are also runtime checks in
/// `ControlBlockStateHelper::write_embedded` / `read_embedded`.
pub trait EmbeddedState
where
    Self: Pod + Zeroable + Default + Copy + Send + Sync + 'static,
{
    /// Layout version of this state type; defaults to `1`.
    const VERSION: u32 = 1;

    /// Reports that values of this type are stored inline; always `true`.
    fn is_embedded() -> bool {
        true
    }
}
/// Compile-time size guard for [`EmbeddedState`] types.
///
/// `SIZE_CHECK` is a unit constant. Evaluating it panics, and therefore fails the build,
/// when `Self` is larger than [`CONTROL_BLOCK_STATE_SIZE`] bytes. Associated constants are
/// evaluated lazily, so the guard only triggers where `<T as EmbeddedStateSize>::SIZE_CHECK`
/// is actually referenced.
pub trait EmbeddedStateSize: EmbeddedState {
    const SIZE_CHECK: () = if std::mem::size_of::<Self>() > CONTROL_BLOCK_STATE_SIZE {
        panic!("EmbeddedState must fit in 24 bytes")
    };
}

// Every embedded state type receives the guard automatically.
impl<T: EmbeddedState> EmbeddedStateSize for T {}
/// Header that can occupy a control block's whole 24-byte reserved area.
///
/// It is recognised by [`STATE_DESCRIPTOR_MAGIC`] in `magic`. A non-zero `external_ptr`
/// marks the state as stored outside the control block. Zero marks it as embedded.
#[derive(Debug, Clone, Copy, Default)]
#[repr(C, align(8))]
pub struct StateDescriptor {
    /// Must equal [`STATE_DESCRIPTOR_MAGIC`] for the descriptor to be considered valid.
    pub magic: u32,
    /// Version of the described state.
    pub version: u32,
    /// Recorded size of the state; presumably in bytes (not interpreted in this module).
    pub total_size: u64,
    /// Address of externally stored state, or `0` when the state is embedded.
    pub external_ptr: u64,
}

// 4 + 4 + 8 + 8 = 24: under `repr(C)` the fields leave no padding. The `Pod` impl below
// relies on that, and this assertion pins it at compile time.
const _: () = assert!(std::mem::size_of::<StateDescriptor>() == 24);

// SAFETY: every field is a plain integer, so the all-zero bit pattern is a valid value.
unsafe impl Zeroable for StateDescriptor {}
// SAFETY: `repr(C)`, `Copy`, `'static`, integer-only fields, and no padding (asserted above),
// so any bit pattern is a valid `StateDescriptor`.
unsafe impl Pod for StateDescriptor {}

impl EmbeddedState for StateDescriptor {}
impl StateDescriptor {
    /// Builds a descriptor stamped with [`STATE_DESCRIPTOR_MAGIC`].
    ///
    /// Passing `external_ptr == 0` describes embedded state.
    pub const fn new(version: u32, total_size: u64, external_ptr: u64) -> Self {
        let magic = STATE_DESCRIPTOR_MAGIC;
        Self {
            magic,
            version,
            total_size,
            external_ptr,
        }
    }

    /// Returns `true` when `magic` carries [`STATE_DESCRIPTOR_MAGIC`].
    pub fn is_valid(&self) -> bool {
        STATE_DESCRIPTOR_MAGIC == self.magic
    }

    /// Returns `true` only for a valid descriptor with a non-zero `external_ptr`.
    pub fn is_external(&self) -> bool {
        if !self.is_valid() {
            return false;
        }
        self.external_ptr != 0
    }

    /// Exact complement of [`is_external`](Self::is_external).
    ///
    /// An invalid descriptor (wrong magic) therefore counts as embedded.
    pub fn is_embedded(&self) -> bool {
        !self.is_external()
    }
}
/// State that can be converted to and from control-block bytes.
///
/// Every [`EmbeddedState`] gets this through a blanket impl. Other types may implement
/// it by hand.
pub trait GpuState: Send + Sync + 'static {
    /// Serializes `self` into raw bytes.
    fn to_control_block_bytes(&self) -> Vec<u8>;

    /// Deserializes a value from `bytes`, or reports why they cannot represent `Self`.
    fn from_control_block_bytes(bytes: &[u8]) -> Result<Self>
    where
        Self: Sized;

    /// Version of the serialized layout; `1` unless overridden.
    fn state_version() -> u32 {
        1
    }

    /// Whether this state should be stored inline in the control block.
    ///
    /// By default this is `true` exactly when `Self` occupies at most
    /// [`CONTROL_BLOCK_STATE_SIZE`] bytes.
    fn prefer_embedded() -> bool
    where
        Self: Sized,
    {
        let footprint = std::mem::size_of::<Self>();
        CONTROL_BLOCK_STATE_SIZE >= footprint
    }
}
/// Every [`EmbeddedState`] is a [`GpuState`] whose serialized form is simply its raw bytes.
impl<T: EmbeddedState> GpuState for T {
    /// Returns a copy of the value's `size_of::<T>()` raw bytes.
    fn to_control_block_bytes(&self) -> Vec<u8> {
        bytemuck::bytes_of(self).to_vec()
    }

    /// Decodes a `T` from the first `size_of::<T>()` bytes of `bytes` and ignores any
    /// trailing bytes. `bytes` may have any alignment.
    ///
    /// # Errors
    /// Returns [`RingKernelError::InvalidState`] if `bytes` is shorter than `size_of::<T>()`.
    fn from_control_block_bytes(bytes: &[u8]) -> Result<Self> {
        let size = std::mem::size_of::<Self>();
        if bytes.len() < size {
            return Err(RingKernelError::InvalidState {
                expected: format!("{} bytes", size),
                actual: format!("{} bytes", bytes.len()),
            });
        }
        // `bytemuck::from_bytes` panics when the slice is not aligned for `T`. A
        // caller-supplied `&[u8]`, such as a sub-slice of a larger buffer, carries no such
        // guarantee. `pod_read_unaligned` copies the bytes out instead. Its only
        // precondition, that the length equals `size_of::<T>()`, holds for this exact slice.
        Ok(bytemuck::pod_read_unaligned(&bytes[..size]))
    }

    /// Reports the type's [`EmbeddedState::VERSION`].
    fn state_version() -> u32 {
        Self::VERSION
    }

    /// Always `true` for embedded state types. This overrides the size-based default of
    /// `GpuState::prefer_embedded`.
    fn prefer_embedded() -> bool {
        true
    }
}
pub struct ControlBlockStateHelper;
impl ControlBlockStateHelper {
pub fn write_embedded<S: EmbeddedState>(block: &mut ControlBlock, state: &S) -> Result<()> {
let bytes = bytemuck::bytes_of(state);
if bytes.len() > CONTROL_BLOCK_STATE_SIZE {
return Err(RingKernelError::InvalidState {
expected: format!("<= {} bytes", CONTROL_BLOCK_STATE_SIZE),
actual: format!("{} bytes", bytes.len()),
});
}
block._reserved = [0u8; 24];
block._reserved[..bytes.len()].copy_from_slice(bytes);
Ok(())
}
pub fn read_embedded<S: EmbeddedState>(block: &ControlBlock) -> Result<S> {
let size = std::mem::size_of::<S>();
if size > CONTROL_BLOCK_STATE_SIZE {
return Err(RingKernelError::InvalidState {
expected: format!("<= {} bytes", CONTROL_BLOCK_STATE_SIZE),
actual: format!("{} bytes", size),
});
}
Ok(*bytemuck::from_bytes(&block._reserved[..size]))
}
pub fn write_descriptor(block: &mut ControlBlock, descriptor: &StateDescriptor) -> Result<()> {
Self::write_embedded(block, descriptor)
}
pub fn read_descriptor(block: &ControlBlock) -> Option<StateDescriptor> {
let desc: StateDescriptor =
*bytemuck::from_bytes::<StateDescriptor>(&block._reserved[..24]);
if desc.is_valid() {
Some(desc)
} else {
None
}
}
pub fn has_embedded_state(block: &ControlBlock) -> bool {
match Self::read_descriptor(block) {
Some(desc) => desc.is_embedded(),
None => true, }
}
pub fn has_external_state(block: &ControlBlock) -> bool {
match Self::read_descriptor(block) {
Some(desc) => desc.is_external(),
None => false,
}
}
pub fn clear_state(block: &mut ControlBlock) {
block._reserved = [0u8; 24];
}
pub fn raw_bytes(block: &ControlBlock) -> &[u8; 24] {
&block._reserved
}
pub fn raw_bytes_mut(block: &mut ControlBlock) -> &mut [u8; 24] {
&mut block._reserved
}
}
/// Owned copy of a kernel's serialized state together with capture metadata.
#[derive(Debug, Clone)]
pub struct StateSnapshot {
    /// Serialized state bytes; decoded by [`StateSnapshot::restore`].
    pub data: Vec<u8>,
    /// State version recorded at capture time.
    pub version: u32,
    /// Whether the state was embedded in the control block when captured.
    pub was_embedded: bool,
    /// Identifier of the kernel the state belongs to.
    pub kernel_id: u64,
    /// Capture timestamp. It is `0` unless set via [`StateSnapshot::with_timestamp`], and
    /// its units are not defined in this module.
    pub timestamp: u64,
}

impl StateSnapshot {
    /// Creates a snapshot with a zero timestamp.
    pub fn new(data: Vec<u8>, version: u32, was_embedded: bool, kernel_id: u64) -> Self {
        Self {
            timestamp: 0,
            kernel_id,
            was_embedded,
            version,
            data,
        }
    }

    /// Returns the same snapshot with `timestamp` replaced (builder style).
    pub fn with_timestamp(self, timestamp: u64) -> Self {
        Self { timestamp, ..self }
    }

    /// Decodes the stored bytes back into an `S`.
    ///
    /// # Errors
    /// Propagates whatever `S::from_control_block_bytes` returns on failure.
    pub fn restore<S: GpuState>(&self) -> Result<S> {
        let bytes: &[u8] = &self.data;
        S::from_control_block_bytes(bytes)
    }
}
// Unit tests for the embedded-state helpers. Each test that touches a control block
// starts from `ControlBlock::new()`.
#[cfg(test)]
mod tests {
    use super::*;

    // Exactly fills the reserved area: 8 + 8 + 4 + 4 = 24 bytes with no padding
    // (the size is asserted in `test_embedded_state_size_check`).
    #[derive(Default, Clone, Copy, Debug, PartialEq)]
    #[repr(C, align(8))]
    struct TestState {
        value_a: u64,
        value_b: u64,
        counter: u32,
        flags: u32,
    }
    // SAFETY: `repr(C)`, integer-only fields, no padding, so every bit pattern is valid.
    unsafe impl Zeroable for TestState {}
    unsafe impl Pod for TestState {}
    impl EmbeddedState for TestState {}

    // An 8-byte state, smaller than the reserved area.
    #[derive(Default, Clone, Copy, Debug, PartialEq)]
    #[repr(C)]
    struct SmallState {
        value: u64,
    }
    // SAFETY: a single `u64` field under `repr(C)`; no padding, any bit pattern is valid.
    unsafe impl Zeroable for SmallState {}
    unsafe impl Pod for SmallState {}
    impl EmbeddedState for SmallState {}

    // Guards against the public constant drifting from the reserved-area length.
    #[test]
    fn test_state_size_constant() {
        assert_eq!(CONTROL_BLOCK_STATE_SIZE, 24);
    }

    // A descriptor must exactly fill the reserved area.
    #[test]
    fn test_state_descriptor_size() {
        assert_eq!(std::mem::size_of::<StateDescriptor>(), 24);
    }

    // External vs. embedded classification, plus an all-zero (invalid) descriptor.
    #[test]
    fn test_state_descriptor_validation() {
        let desc = StateDescriptor::new(1, 256, 0x1000);
        assert!(desc.is_valid());
        assert!(desc.is_external());
        assert!(!desc.is_embedded());
        let embedded_desc = StateDescriptor::new(1, 24, 0);
        assert!(embedded_desc.is_valid());
        assert!(!embedded_desc.is_external());
        assert!(embedded_desc.is_embedded());
        let invalid_desc = StateDescriptor::default();
        assert!(!invalid_desc.is_valid());
    }

    // Round-trip of a state that uses all 24 bytes.
    #[test]
    fn test_write_read_embedded_state() {
        let mut block = ControlBlock::new();
        let state = TestState {
            value_a: 0x1234567890ABCDEF,
            value_b: 0xFEDCBA0987654321,
            counter: 42,
            flags: 0xFF,
        };
        ControlBlockStateHelper::write_embedded(&mut block, &state).unwrap();
        let restored: TestState = ControlBlockStateHelper::read_embedded(&block).unwrap();
        assert_eq!(state, restored);
    }

    // Round-trip of a state that uses only a prefix of the reserved area.
    #[test]
    fn test_write_read_small_state() {
        let mut block = ControlBlock::new();
        let state = SmallState { value: 42 };
        ControlBlockStateHelper::write_embedded(&mut block, &state).unwrap();
        let restored: SmallState = ControlBlockStateHelper::read_embedded(&block).unwrap();
        assert_eq!(state, restored);
    }

    // A written descriptor is read back with its magic and all fields intact.
    #[test]
    fn test_write_read_descriptor() {
        let mut block = ControlBlock::new();
        let desc = StateDescriptor::new(2, 1024, 0xDEADBEEF);
        ControlBlockStateHelper::write_descriptor(&mut block, &desc).unwrap();
        let restored = ControlBlockStateHelper::read_descriptor(&block).unwrap();
        assert_eq!(restored.magic, STATE_DESCRIPTOR_MAGIC);
        assert_eq!(restored.version, 2);
        assert_eq!(restored.total_size, 1024);
        assert_eq!(restored.external_ptr, 0xDEADBEEF);
    }

    // A fresh block (no valid descriptor) counts as embedded. An external descriptor flips
    // that, and a descriptor with `external_ptr == 0` flips it back.
    #[test]
    fn test_has_embedded_external_state() {
        let mut block = ControlBlock::new();
        assert!(ControlBlockStateHelper::has_embedded_state(&block));
        assert!(!ControlBlockStateHelper::has_external_state(&block));
        let desc = StateDescriptor::new(1, 256, 0x1000);
        ControlBlockStateHelper::write_descriptor(&mut block, &desc).unwrap();
        assert!(!ControlBlockStateHelper::has_embedded_state(&block));
        assert!(ControlBlockStateHelper::has_external_state(&block));
        let desc = StateDescriptor::new(1, 24, 0);
        ControlBlockStateHelper::write_descriptor(&mut block, &desc).unwrap();
        assert!(ControlBlockStateHelper::has_embedded_state(&block));
        assert!(!ControlBlockStateHelper::has_external_state(&block));
    }

    // `clear_state` zeroes every reserved byte after a non-zero write.
    #[test]
    fn test_clear_state() {
        let mut block = ControlBlock::new();
        let state = TestState {
            value_a: 123,
            value_b: 456,
            counter: 789,
            flags: 0xABC,
        };
        ControlBlockStateHelper::write_embedded(&mut block, &state).unwrap();
        assert!(block._reserved.iter().any(|&b| b != 0));
        ControlBlockStateHelper::clear_state(&mut block);
        assert!(block._reserved.iter().all(|&b| b == 0));
    }

    // The raw accessors alias `_reserved` directly (first and last byte checked).
    #[test]
    fn test_raw_bytes_access() {
        let mut block = ControlBlock::new();
        block._reserved[0] = 0x42;
        block._reserved[23] = 0xFF;
        let bytes = ControlBlockStateHelper::raw_bytes(&block);
        assert_eq!(bytes[0], 0x42);
        assert_eq!(bytes[23], 0xFF);
        let bytes_mut = ControlBlockStateHelper::raw_bytes_mut(&mut block);
        bytes_mut[1] = 0x99;
        assert_eq!(block._reserved[1], 0x99);
    }

    // The blanket `GpuState` impl for `EmbeddedState` types: byte round-trip and defaults.
    #[test]
    fn test_gpu_state_trait() {
        let state = TestState {
            value_a: 100,
            value_b: 200,
            counter: 300,
            flags: 400,
        };
        let bytes = state.to_control_block_bytes();
        assert_eq!(bytes.len(), 24);
        let restored = TestState::from_control_block_bytes(&bytes).unwrap();
        assert_eq!(state, restored);
        assert!(TestState::prefer_embedded());
        assert_eq!(TestState::state_version(), 1);
    }

    // Snapshot metadata is preserved and `restore` decodes the captured bytes.
    #[test]
    fn test_state_snapshot() {
        let state = TestState {
            value_a: 1,
            value_b: 2,
            counter: 3,
            flags: 4,
        };
        let snapshot =
            StateSnapshot::new(state.to_control_block_bytes(), 1, true, 42).with_timestamp(1000);
        assert_eq!(snapshot.version, 1);
        assert!(snapshot.was_embedded);
        assert_eq!(snapshot.kernel_id, 42);
        assert_eq!(snapshot.timestamp, 1000);
        let restored: TestState = snapshot.restore().unwrap();
        assert_eq!(state, restored);
    }

    // Referencing `SIZE_CHECK` forces its compile-time size assertion to be evaluated
    // for both test types.
    #[test]
    fn test_embedded_state_size_check() {
        assert_eq!(std::mem::size_of::<TestState>(), 24);
        assert_eq!(<TestState as EmbeddedStateSize>::SIZE_CHECK, ());
        assert!(std::mem::size_of::<SmallState>() <= CONTROL_BLOCK_STATE_SIZE);
        assert_eq!(<SmallState as EmbeddedStateSize>::SIZE_CHECK, ());
    }
}