#![allow(unused_unsafe)]
use crate::wireguard_nt_raw;
use std::{alloc::Layout, sync::Arc};
/// Newtype wrapper whose `Send` and `Sync` implementations below are
/// unconditional, i.e. they hold for every `T`, including types that are
/// neither `Send` nor `Sync` themselves.
///
/// NOTE(review): presumably used to hand raw FFI handles to other threads —
/// confirm at the use sites that the wrapped values really are thread-safe.
pub(crate) struct UnsafeHandle<T>(pub T);
// SAFETY: not verified by the compiler. Because there is no bound on `T`, the
// code constructing an `UnsafeHandle` is responsible for the wrapped value
// being safe to move to another thread.
unsafe impl<T> Send for UnsafeHandle<T> {}
// SAFETY: not verified by the compiler. Because there is no bound on `T`, the
// code constructing an `UnsafeHandle` is responsible for the wrapped value
// being safe to access from several threads at once.
unsafe impl<T> Sync for UnsafeHandle<T> {}
/// Returns the value reported by `WireGuardGetRunningDriverVersion` for the
/// given WireGuard library handle.
///
/// The result is passed through unchanged.
///
/// NOTE(review): the upstream WireGuardNT API appears to use a return value of
/// zero to signal failure (e.g. driver not loaded) — confirm against the
/// WireGuardNT documentation and decide whether callers need to check for it.
pub fn get_running_driver_version(wireguard: &Arc<wireguard_nt_raw::wireguard>) -> u32 {
// SAFETY (assumed — TODO confirm): the call takes no arguments besides the
// library handle, which the `Arc` keeps alive for the duration of the call.
unsafe { wireguard.WireGuardGetRunningDriverVersion() }
}
/// Sequential writer over a single zero-initialised heap allocation.
///
/// Values are handed out back to back, starting at the first byte of the
/// allocation, so the finished buffer can be passed on as one contiguous block
/// through [`StructWriter::ptr`]. The allocation is released on drop.
pub(crate) struct StructWriter {
    /// First byte of the buffer. For a zero-sized layout this is an aligned
    /// dangling pointer that was never returned by the allocator.
    start: *mut u8,
    /// Number of bytes already handed out by [`StructWriter::write`].
    offset: usize,
    /// Layout used for the allocation; required again to deallocate it.
    layout: Layout,
}

impl StructWriter {
    /// Creates a writer backed by `capacity` zeroed bytes aligned to `align`.
    ///
    /// A `capacity` of zero is allowed and performs no allocation.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` and `align` do not form a valid [`Layout`]
    /// (for example when `align` is not a power of two).
    ///
    /// If the allocator reports failure, `std::alloc::handle_alloc_error` is
    /// invoked instead of continuing with a null pointer.
    pub fn new(capacity: usize, align: usize) -> Self {
        let layout = Layout::from_size_align(capacity, align).unwrap();
        let start = if layout.size() == 0 {
            // The global allocator must not be called with a zero-sized
            // layout. A non-null pointer that satisfies the requested
            // alignment is all that zero-sized accesses need.
            layout.align() as *mut u8
        } else {
            // SAFETY: `layout` has a non-zero size.
            let ptr = unsafe { std::alloc::alloc_zeroed(layout) };
            if ptr.is_null() {
                std::alloc::handle_alloc_error(layout);
            }
            ptr
        };
        Self {
            start,
            offset: 0,
            layout,
        }
    }

    /// Reserves the next `size_of::<T>()` bytes of the buffer and returns them
    /// as a mutable reference to `T`.
    ///
    /// No padding is inserted: the value is placed exactly at the current
    /// offset, so the caller decides the write order that keeps every value
    /// aligned.
    ///
    /// # Safety
    ///
    /// The bytes behind the returned reference are whatever the buffer holds
    /// at that position (all zeroes unless modified through other means), so
    /// that bit pattern must be a valid value of `T`.
    ///
    /// # Panics
    ///
    /// Panics if fewer than `size_of::<T>()` bytes remain, or if the current
    /// position is not aligned for `T`. In both cases the writer is left
    /// unchanged.
    pub unsafe fn write<T>(&mut self) -> &mut T {
        let size = std::mem::size_of::<T>();
        // `offset` never exceeds the capacity, so this cannot underflow, and
        // comparing against the remainder avoids any overflow in the check.
        let remaining = self.layout.size() - self.offset;
        if size > remaining {
            panic!(
                "Overflow attempting to write struct of size {}. To allocation size: {}, offset: {}",
                size,
                self.layout.size(),
                self.offset
            );
        }
        // SAFETY: `offset + size <= capacity`, so the result stays inside (or
        // one past the end of) the allocation.
        let ptr = unsafe { self.start.add(self.offset) };
        // Check alignment before advancing so that a failed check does not
        // leave the writer with bytes skipped.
        assert_eq!(ptr as usize % std::mem::align_of::<T>(), 0);
        self.offset += size;
        // SAFETY: `ptr` is in bounds for `size` bytes and aligned for `T`;
        // validity of the bit pattern is the caller's obligation (see above).
        unsafe { &mut *ptr.cast::<T>() }
    }

    /// Returns a pointer to the first byte of the buffer.
    pub fn ptr(&self) -> *const u8 {
        self.start
    }

    /// Returns `true` once every byte of the buffer has been handed out.
    pub fn is_full(&self) -> bool {
        self.layout.size() == self.offset
    }
}

impl Drop for StructWriter {
    fn drop(&mut self) {
        // Zero-sized buffers were never allocated, so there is nothing to free.
        if self.layout.size() != 0 {
            // SAFETY: `start` was returned by the global allocator for exactly
            // this `layout` and has not been freed before.
            unsafe { std::alloc::dealloc(self.start, self.layout) };
        }
    }
}
/// Sequential reader over a single heap allocation.
///
/// The buffer is filled from outside through [`StructReader::ptr_mut`] and
/// then consumed value by value with [`StructReader::read`], starting at the
/// first byte. The allocation is released on drop.
pub(crate) struct StructReader {
    /// First byte of the buffer. For a zero-sized layout this is an aligned
    /// dangling pointer that was never returned by the allocator.
    start: *mut u8,
    /// Number of bytes already consumed by [`StructReader::read`].
    offset: usize,
    /// Layout used for the allocation; required again to deallocate it.
    layout: Layout,
}

impl StructReader {
    /// Creates a reader backed by `capacity` bytes aligned to `align`.
    ///
    /// The buffer starts out zeroed so that bytes nobody has filled in yet are
    /// never read as uninitialised memory. A `capacity` of zero is allowed and
    /// performs no allocation.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` and `align` do not form a valid [`Layout`]
    /// (for example when `align` is not a power of two).
    ///
    /// If the allocator reports failure, `std::alloc::handle_alloc_error` is
    /// invoked instead of continuing with a null pointer.
    pub fn new(capacity: usize, align: usize) -> Self {
        let layout = Layout::from_size_align(capacity, align).unwrap();
        let start = if layout.size() == 0 {
            // The global allocator must not be called with a zero-sized
            // layout. A non-null pointer that satisfies the requested
            // alignment is all that zero-sized accesses need.
            layout.align() as *mut u8
        } else {
            // SAFETY: `layout` has a non-zero size.
            let ptr = unsafe { std::alloc::alloc_zeroed(layout) };
            if ptr.is_null() {
                std::alloc::handle_alloc_error(layout);
            }
            ptr
        };
        Self {
            start,
            offset: 0,
            layout,
        }
    }

    /// Consumes the next `size_of::<T>()` bytes of the buffer and returns them
    /// as a shared reference to `T`.
    ///
    /// No padding is skipped: the value is taken exactly from the current
    /// offset.
    ///
    /// # Safety
    ///
    /// The bytes at the current offset must form a valid value of `T`. They
    /// are zero unless something was written through
    /// [`StructReader::ptr_mut`].
    ///
    /// # Panics
    ///
    /// Panics if fewer than `size_of::<T>()` bytes remain, or if the current
    /// position is not aligned for `T`. In both cases the reader is left
    /// unchanged.
    pub unsafe fn read<T>(&mut self) -> &T {
        let size = std::mem::size_of::<T>();
        // `offset` never exceeds the capacity, so this cannot underflow, and
        // comparing against the remainder avoids any overflow in the check.
        let remaining = self.layout.size() - self.offset;
        if size > remaining {
            panic!(
                "Overflow attempting to read struct of size {}. To allocation size: {}, offset: {}",
                size,
                self.layout.size(),
                self.offset
            );
        }
        // SAFETY: `offset + size <= capacity`, so the result stays inside (or
        // one past the end of) the allocation.
        let ptr = unsafe { self.start.add(self.offset) };
        // Check alignment before advancing so that a failed check does not
        // leave the reader with bytes skipped.
        assert_eq!(ptr as usize % std::mem::align_of::<T>(), 0);
        self.offset += size;
        // SAFETY: `ptr` is in bounds for `size` bytes and aligned for `T`;
        // validity of the bit pattern is the caller's obligation (see above).
        unsafe { &*ptr.cast::<T>() }
    }

    /// Returns a mutable pointer to the first byte of the buffer, through
    /// which the buffer is meant to be filled.
    pub fn ptr_mut(&self) -> *mut u8 {
        self.start
    }

    /// Returns `true` once every byte of the buffer has been consumed.
    #[allow(dead_code)]
    pub fn is_full(&self) -> bool {
        self.layout.size() == self.offset
    }
}

impl Drop for StructReader {
    fn drop(&mut self) {
        // Zero-sized buffers were never allocated, so there is nothing to free.
        if self.layout.size() != 0 {
            // SAFETY: `start` was returned by the global allocator for exactly
            // this `layout` and has not been freed before.
            unsafe { std::alloc::dealloc(self.start, self.layout) };
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::mem::{align_of_val, size_of_val};

    /// Bytes placed into the reader's buffer are returned as the matching
    /// `#[repr(C)]` struct.
    #[test]
    fn reader_basic() {
        #[derive(Debug)]
        #[repr(C)]
        struct Data {
            field_a: u8,
            field_b: u32,
        }

        let expected_data = Data {
            field_a: 0b10000001,
            field_b: 0x00FFFF00,
        };
        let capacity = size_of_val(&expected_data);
        let alignment = align_of_val(&expected_data);
        let mut reader = StructReader::new(capacity, alignment);

        // Fill in `field_a` (byte 0) and `field_b` (bytes 4..8); bytes 1..4
        // are padding and are left untouched.
        let byte_buffer: &mut [u8; 8] = unsafe { &mut *(reader.ptr_mut() as *mut [u8; 8]) };
        let assignments: [(usize, u8); 5] =
            [(0, 0b10000001), (4, 0x0), (5, 0xFF), (6, 0xFF), (7, 0x0)];
        for &(index, value) in assignments.iter() {
            byte_buffer[index] = value;
        }

        let actual_data = unsafe { reader.read::<Data>() };
        assert_eq!(actual_data.field_a, expected_data.field_a);
        assert_eq!(actual_data.field_b, expected_data.field_b);
    }

    /// Reading a value larger than the buffer must panic.
    #[test]
    #[should_panic]
    fn reader_overflow() {
        let mut reader = StructReader::new(1, align_of_val(&1));
        unsafe { reader.read::<u64>() };
    }

    /// Consecutive writes land back to back in the buffer.
    #[test]
    fn writer_basic() {
        let mut buf = StructWriter::new(20, 4);
        unsafe {
            *buf.write::<u16>() = 0;
            *buf.write::<u16>() = 0xCCCC;
            *buf.write::<u32>() = 0x00FFFF00;
            *buf.write::<u32>() = 0x80808080;
            *buf.write::<u32>() = 0xFFFFFFFF;
        }

        let expected: [u8; 16] = [
            0, 0, 204, 204, 0, 255, 255, 0, 128, 128, 128, 128, 255, 255, 255, 255,
        ];
        let written: &[u8] = unsafe { std::slice::from_raw_parts(buf.ptr(), 16) };
        assert_eq!(written, &expected);
    }

    /// A `u32` requested at an odd offset must panic.
    #[test]
    #[should_panic]
    fn writer_unaligned() {
        let mut buf = StructWriter::new(8, 4);
        unsafe {
            *buf.write::<u8>() = 0;
            *buf.write::<u32>() = 0xFFFFFFFF;
        }
    }

    /// Writing one byte more than the capacity must panic.
    #[test]
    #[should_panic]
    fn writer_overflow() {
        let mut buf = StructWriter::new(16, 4);
        let words: [u32; 4] = [0xFFFFFFFF, 0xFFFFFF00, 0xFFFF00FF, 0xFF00FFFF];
        for &word in words.iter() {
            *unsafe { buf.write::<u32>() } = word;
        }
        // The buffer is now exactly full; this extra byte does not fit.
        *unsafe { buf.write::<u8>() } = 5;
    }
}