1use crate::control::ControlBlock;
33use crate::error::{Result, RingKernelError};
34use bytemuck::{Pod, Zeroable};
35
36pub const CONTROL_BLOCK_STATE_SIZE: usize = 24;
38
39pub const STATE_DESCRIPTOR_MAGIC: u32 = 0x54415453; pub trait EmbeddedState: Pod + Zeroable + Default + Copy + Send + Sync + 'static {
69 const VERSION: u32 = 1;
72
73 fn is_embedded() -> bool {
75 true
76 }
77}
78
79pub trait EmbeddedStateSize: EmbeddedState {
83 const SIZE_CHECK: () = assert!(
85 std::mem::size_of::<Self>() <= CONTROL_BLOCK_STATE_SIZE,
86 "EmbeddedState must fit in 24 bytes"
87 );
88}
89
90impl<T: EmbeddedState> EmbeddedStateSize for T {}
92
93#[derive(Debug, Clone, Copy, Default)]
102#[repr(C, align(8))]
103pub struct StateDescriptor {
104 pub magic: u32,
106 pub version: u32,
108 pub total_size: u64,
110 pub external_ptr: u64,
112}
113
114unsafe impl Zeroable for StateDescriptor {}
116unsafe impl Pod for StateDescriptor {}
117
118impl EmbeddedState for StateDescriptor {}
119
120const _: () = assert!(std::mem::size_of::<StateDescriptor>() == 24);
121
122impl StateDescriptor {
123 pub const fn new(version: u32, total_size: u64, external_ptr: u64) -> Self {
125 Self {
126 magic: STATE_DESCRIPTOR_MAGIC,
127 version,
128 total_size,
129 external_ptr,
130 }
131 }
132
133 pub fn is_valid(&self) -> bool {
135 self.magic == STATE_DESCRIPTOR_MAGIC
136 }
137
138 pub fn is_external(&self) -> bool {
140 self.is_valid() && self.external_ptr != 0
141 }
142
143 pub fn is_embedded(&self) -> bool {
145 !self.is_valid() || self.external_ptr == 0
146 }
147}
148
149pub trait GpuState: Send + Sync + 'static {
158 fn to_control_block_bytes(&self) -> Vec<u8>;
160
161 fn from_control_block_bytes(bytes: &[u8]) -> Result<Self>
163 where
164 Self: Sized;
165
166 fn state_version() -> u32 {
168 1
169 }
170
171 fn prefer_embedded() -> bool
173 where
174 Self: Sized,
175 {
176 std::mem::size_of::<Self>() <= CONTROL_BLOCK_STATE_SIZE
177 }
178}
179
180impl<T: EmbeddedState> GpuState for T {
182 fn to_control_block_bytes(&self) -> Vec<u8> {
183 bytemuck::bytes_of(self).to_vec()
184 }
185
186 fn from_control_block_bytes(bytes: &[u8]) -> Result<Self> {
187 if bytes.len() < std::mem::size_of::<Self>() {
188 return Err(RingKernelError::InvalidState {
189 expected: format!("{} bytes", std::mem::size_of::<Self>()),
190 actual: format!("{} bytes", bytes.len()),
191 });
192 }
193 Ok(*bytemuck::from_bytes(&bytes[..std::mem::size_of::<Self>()]))
194 }
195
196 fn state_version() -> u32 {
197 Self::VERSION
198 }
199
200 fn prefer_embedded() -> bool {
201 true
202 }
203}
204
205pub struct ControlBlockStateHelper;
211
212impl ControlBlockStateHelper {
213 pub fn write_embedded<S: EmbeddedState>(block: &mut ControlBlock, state: &S) -> Result<()> {
219 let bytes = bytemuck::bytes_of(state);
220 if bytes.len() > CONTROL_BLOCK_STATE_SIZE {
221 return Err(RingKernelError::InvalidState {
222 expected: format!("<= {} bytes", CONTROL_BLOCK_STATE_SIZE),
223 actual: format!("{} bytes", bytes.len()),
224 });
225 }
226
227 block._reserved = [0u8; 24];
229
230 block._reserved[..bytes.len()].copy_from_slice(bytes);
232
233 Ok(())
234 }
235
236 pub fn read_embedded<S: EmbeddedState>(block: &ControlBlock) -> Result<S> {
242 let size = std::mem::size_of::<S>();
243 if size > CONTROL_BLOCK_STATE_SIZE {
244 return Err(RingKernelError::InvalidState {
245 expected: format!("<= {} bytes", CONTROL_BLOCK_STATE_SIZE),
246 actual: format!("{} bytes", size),
247 });
248 }
249
250 Ok(*bytemuck::from_bytes(&block._reserved[..size]))
251 }
252
253 pub fn write_descriptor(block: &mut ControlBlock, descriptor: &StateDescriptor) -> Result<()> {
255 Self::write_embedded(block, descriptor)
256 }
257
258 pub fn read_descriptor(block: &ControlBlock) -> Option<StateDescriptor> {
262 let desc: StateDescriptor =
263 *bytemuck::from_bytes::<StateDescriptor>(&block._reserved[..24]);
264 if desc.is_valid() {
265 Some(desc)
266 } else {
267 None
268 }
269 }
270
271 pub fn has_embedded_state(block: &ControlBlock) -> bool {
273 match Self::read_descriptor(block) {
274 Some(desc) => desc.is_embedded(),
275 None => true, }
277 }
278
279 pub fn has_external_state(block: &ControlBlock) -> bool {
281 match Self::read_descriptor(block) {
282 Some(desc) => desc.is_external(),
283 None => false,
284 }
285 }
286
287 pub fn clear_state(block: &mut ControlBlock) {
289 block._reserved = [0u8; 24];
290 }
291
292 pub fn raw_bytes(block: &ControlBlock) -> &[u8; 24] {
294 &block._reserved
295 }
296
297 pub fn raw_bytes_mut(block: &mut ControlBlock) -> &mut [u8; 24] {
299 &mut block._reserved
300 }
301}
302
303#[derive(Debug, Clone)]
309pub struct StateSnapshot {
310 pub data: Vec<u8>,
312 pub version: u32,
314 pub was_embedded: bool,
316 pub kernel_id: u64,
318 pub timestamp: u64,
320}
321
322impl StateSnapshot {
323 pub fn new(data: Vec<u8>, version: u32, was_embedded: bool, kernel_id: u64) -> Self {
325 Self {
326 data,
327 version,
328 was_embedded,
329 kernel_id,
330 timestamp: 0,
331 }
332 }
333
334 pub fn with_timestamp(mut self, timestamp: u64) -> Self {
336 self.timestamp = timestamp;
337 self
338 }
339
340 pub fn restore<S: GpuState>(&self) -> Result<S> {
342 S::from_control_block_bytes(&self.data)
343 }
344}
345
#[cfg(test)]
mod tests {
    // Unit tests for embedded/descriptor state handling in the control block.
    use super::*;

    /// 24-byte state that exactly fills the state area.
    #[derive(Default, Clone, Copy, Debug, PartialEq)]
    #[repr(C, align(8))]
    struct TestState {
        value_a: u64,
        value_b: u64,
        counter: u32,
        flags: u32,
    }

    // SAFETY: `repr(C)` struct of integer fields only. 8 + 8 + 4 + 4 = 24
    // bytes with 8-byte alignment leaves no padding, so all-zero is valid.
    unsafe impl Zeroable for TestState {}
    // SAFETY: same layout argument; every bit pattern is valid and the type has
    // no references or interior mutability.
    unsafe impl Pod for TestState {}

    impl EmbeddedState for TestState {}

    /// 8-byte state that only partially fills the state area.
    #[derive(Default, Clone, Copy, Debug, PartialEq)]
    #[repr(C)]
    struct SmallState {
        value: u64,
    }

    // SAFETY: single `u64` field under `repr(C)`; no padding, any bit pattern valid.
    unsafe impl Zeroable for SmallState {}
    // SAFETY: as above; plain integer data with no references.
    unsafe impl Pod for SmallState {}

    impl EmbeddedState for SmallState {}

    #[test]
    fn test_state_size_constant() {
        assert_eq!(CONTROL_BLOCK_STATE_SIZE, 24);
    }

    #[test]
    fn test_state_descriptor_size() {
        assert_eq!(std::mem::size_of::<StateDescriptor>(), 24);
    }

    #[test]
    fn test_state_descriptor_validation() {
        let desc = StateDescriptor::new(1, 256, 0x1000);
        assert!(desc.is_valid());
        assert!(desc.is_external());
        assert!(!desc.is_embedded());

        let embedded_desc = StateDescriptor::new(1, 24, 0);
        assert!(embedded_desc.is_valid());
        assert!(!embedded_desc.is_external());
        assert!(embedded_desc.is_embedded());

        let invalid_desc = StateDescriptor::default();
        assert!(!invalid_desc.is_valid());
    }

    #[test]
    fn test_write_read_embedded_state() {
        let mut block = ControlBlock::new();
        let state = TestState {
            value_a: 0x1234567890ABCDEF,
            value_b: 0xFEDCBA0987654321,
            counter: 42,
            flags: 0xFF,
        };

        ControlBlockStateHelper::write_embedded(&mut block, &state).unwrap();
        let restored: TestState = ControlBlockStateHelper::read_embedded(&block).unwrap();

        assert_eq!(state, restored);
    }

    #[test]
    fn test_write_read_small_state() {
        let mut block = ControlBlock::new();
        let state = SmallState { value: 42 };

        ControlBlockStateHelper::write_embedded(&mut block, &state).unwrap();
        let restored: SmallState = ControlBlockStateHelper::read_embedded(&block).unwrap();

        assert_eq!(state, restored);
    }

    #[test]
    fn test_write_read_descriptor() {
        let mut block = ControlBlock::new();
        let desc = StateDescriptor::new(2, 1024, 0xDEADBEEF);

        ControlBlockStateHelper::write_descriptor(&mut block, &desc).unwrap();

        let restored = ControlBlockStateHelper::read_descriptor(&block).unwrap();
        assert_eq!(restored.magic, STATE_DESCRIPTOR_MAGIC);
        assert_eq!(restored.version, 2);
        assert_eq!(restored.total_size, 1024);
        assert_eq!(restored.external_ptr, 0xDEADBEEF);
    }

    #[test]
    fn test_read_descriptor_on_cleared_block_is_none() {
        let mut block = ControlBlock::new();
        ControlBlockStateHelper::clear_state(&mut block);
        // A zeroed magic tag never matches STATE_DESCRIPTOR_MAGIC.
        assert!(ControlBlockStateHelper::read_descriptor(&block).is_none());
    }

    #[test]
    fn test_has_embedded_external_state() {
        let mut block = ControlBlock::new();

        // A fresh block reports embedded (not external) state.
        assert!(ControlBlockStateHelper::has_embedded_state(&block));
        assert!(!ControlBlockStateHelper::has_external_state(&block));

        // Descriptor with a non-zero pointer => external.
        let desc = StateDescriptor::new(1, 256, 0x1000);
        ControlBlockStateHelper::write_descriptor(&mut block, &desc).unwrap();

        assert!(!ControlBlockStateHelper::has_embedded_state(&block));
        assert!(ControlBlockStateHelper::has_external_state(&block));

        // Descriptor with a zero pointer => embedded.
        let desc = StateDescriptor::new(1, 24, 0);
        ControlBlockStateHelper::write_descriptor(&mut block, &desc).unwrap();

        assert!(ControlBlockStateHelper::has_embedded_state(&block));
        assert!(!ControlBlockStateHelper::has_external_state(&block));
    }

    #[test]
    fn test_clear_state() {
        let mut block = ControlBlock::new();
        let state = TestState {
            value_a: 123,
            value_b: 456,
            counter: 789,
            flags: 0xABC,
        };

        ControlBlockStateHelper::write_embedded(&mut block, &state).unwrap();
        assert!(block._reserved.iter().any(|&b| b != 0));

        ControlBlockStateHelper::clear_state(&mut block);
        assert!(block._reserved.iter().all(|&b| b == 0));
    }

    #[test]
    fn test_raw_bytes_access() {
        let mut block = ControlBlock::new();
        block._reserved[0] = 0x42;
        block._reserved[23] = 0xFF;

        let bytes = ControlBlockStateHelper::raw_bytes(&block);
        assert_eq!(bytes[0], 0x42);
        assert_eq!(bytes[23], 0xFF);

        let bytes_mut = ControlBlockStateHelper::raw_bytes_mut(&mut block);
        bytes_mut[1] = 0x99;
        assert_eq!(block._reserved[1], 0x99);
    }

    #[test]
    fn test_gpu_state_trait() {
        let state = TestState {
            value_a: 100,
            value_b: 200,
            counter: 300,
            flags: 400,
        };

        let bytes = state.to_control_block_bytes();
        assert_eq!(bytes.len(), 24);

        let restored = TestState::from_control_block_bytes(&bytes).unwrap();
        assert_eq!(state, restored);

        assert!(TestState::prefer_embedded());
        assert_eq!(TestState::state_version(), 1);
    }

    #[test]
    fn test_from_control_block_bytes_rejects_short_input() {
        let result = TestState::from_control_block_bytes(&[0u8; 8]);
        assert!(matches!(
            result,
            Err(RingKernelError::InvalidState { .. })
        ));
    }

    #[test]
    fn test_state_snapshot() {
        let state = TestState {
            value_a: 1,
            value_b: 2,
            counter: 3,
            flags: 4,
        };

        let snapshot =
            StateSnapshot::new(state.to_control_block_bytes(), 1, true, 42).with_timestamp(1000);

        assert_eq!(snapshot.version, 1);
        assert!(snapshot.was_embedded);
        assert_eq!(snapshot.kernel_id, 42);
        assert_eq!(snapshot.timestamp, 1000);

        let restored: TestState = snapshot.restore().unwrap();
        assert_eq!(state, restored);
    }

    #[test]
    fn test_embedded_state_size_check() {
        // Naming SIZE_CHECK in a `const` item forces the compile-time size
        // assertion to be evaluated. The previous `assert_eq!(.., ())`
        // compared unit values at runtime, which always succeeds and is
        // rejected by clippy's `unit_cmp` lint.
        const _: () = <TestState as EmbeddedStateSize>::SIZE_CHECK;
        const _: () = <SmallState as EmbeddedStateSize>::SIZE_CHECK;

        assert_eq!(std::mem::size_of::<TestState>(), 24);
        assert!(std::mem::size_of::<SmallState>() <= CONTROL_BLOCK_STATE_SIZE);
    }
}
556}