use std::alloc::{alloc_zeroed, dealloc, realloc, Layout};
use std::any::TypeId;
use std::fmt::{Debug, Formatter};
use std::mem::size_of;
use bytemuck::{Pod, Zeroable};
use zune_core::bit_depth::BitType;
/// Alignment, in bytes, requested for every buffer a `Channel` allocates.
///
/// It is passed as the `align` argument of every `Layout` built by the
/// channel's allocation helpers.
pub const MIN_ALIGNMENT: usize = 64;
/// Reasons a `Channel` refuses to be viewed as a slice of some type `T`.
#[derive(Copy, Clone)]
pub enum ChannelErrors {
/// The channel's pointer is not suitably aligned for the requested type.
///
/// Fields, as populated by `Channel::confirm_suspicions`: the pointer
/// address, then `size_of::<T>()` of the requested type.
UnalignedPointer(usize, usize),
/// The channel's byte length is not a multiple of the requested type's size.
///
/// Fields: the channel length in bytes, then `size_of::<T>()`.
UnevenLength(usize, usize),
/// The requested type is not the type the channel was created with.
///
/// Fields: the `TypeId` stored in the channel, then the `TypeId` of the
/// requested type.
DifferentType(TypeId, TypeId)
}
impl Debug for ChannelErrors {
    /// Writes a human-readable, newline-terminated description of the error.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // The enum is `Copy`, so the payloads can be bound by value.
        match *self {
            ChannelErrors::UnalignedPointer(address, alignment) => {
                writeln!(f, "Channel pointer {address} is not aligned to {alignment}")
            }
            ChannelErrors::UnevenLength(byte_len, elem_size) => {
                writeln!(f, "Size of {elem_size} cannot evenly divide length {byte_len}")
            }
            ChannelErrors::DifferentType(first, second) => writeln!(
                f,
                "Different type id {first:?} from expected {second:?}. This indicates you are converting a channel\nto a type it wasn't instantiated with"
            ),
        }
    }
}
/// A type-erased, heap-allocated byte buffer tagged with the `TypeId` of the
/// element type it was created for.
#[derive(Eq)]
pub struct Channel {
// Start of the buffer obtained from the channel's allocation helpers.
ptr: *mut u8,
// Number of bytes currently in use (not a number of elements): `push`
// advances it by `size_of::<T>()` per element.
length: usize,
// Number of bytes allocated behind `ptr`.
capacity: usize,
// Type the channel was created with; checked before typed access.
type_id: TypeId
}
// `Channel` contains a raw `*mut u8`, so the compiler does not implement
// `Send`/`Sync` for it automatically; these impls opt in by hand.
// NOTE(review): the source gives no safety justification here. Presumably the
// channel is the sole owner of its buffer and all mutation goes through
// `&mut self` -- confirm before relying on it.
unsafe impl Send for Channel {}
unsafe impl Sync for Channel {}
impl Clone for Channel {
    /// Produces an independent channel with the same capacity, type id and
    /// byte contents as `self`.
    fn clone(&self) -> Self {
        let mut duplicate = Channel::new_with_capacity_and_type(self.capacity(), self.type_id);
        // View our contents as plain bytes; a byte view needs no type check.
        let bytes: &[u8] = unsafe { self.reinterpret_as_unchecked::<u8>() };
        unsafe { duplicate.extend_unchecked(bytes) };
        duplicate
    }
}
impl PartialEq for Channel {
    /// Two channels are equal when they have the same byte length, the same
    /// type id and identical byte contents. Capacity is not compared.
    fn eq(&self, other: &Self) -> bool {
        // Cheap metadata checks first.
        if self.length != other.length || self.type_id != other.type_id {
            return false;
        }
        let (lhs, rhs) = unsafe {
            (
                self.reinterpret_as_unchecked::<u8>(),
                other.reinterpret_as_unchecked::<u8>()
            )
        };
        // Byte-wise comparison of the two buffers.
        lhs.iter().zip(rhs).all(|(ours, theirs)| *ours == *theirs)
    }
}
impl Debug for Channel {
    /// Prints the in-use bytes of the channel as a `u8` slice, followed by a
    /// newline.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let in_use = self.length;
        let bytes: &[u8] = unsafe { std::slice::from_raw_parts(self.ptr, in_use) };
        writeln!(f, "raw_bytes: {:?}", bytes)
    }
}
impl Channel {
pub const fn capacity(&self) -> usize {
self.capacity
}
pub const fn len(&self) -> usize {
self.length
}
pub const fn is_empty(&self) -> bool {
self.length == 0
}
unsafe fn alloc(size: usize) -> *mut u8 {
let layout = Layout::from_size_align(size, MIN_ALIGNMENT).unwrap();
alloc_zeroed(layout)
}
unsafe fn realloc(&mut self, new_size: usize) {
let layout = Layout::from_size_align(new_size, MIN_ALIGNMENT).unwrap();
self.ptr = realloc(self.ptr, layout, new_size);
self.capacity = new_size;
}
unsafe fn dealloc(&mut self) {
let layout = Layout::from_size_align(self.capacity, MIN_ALIGNMENT).unwrap();
dealloc(self.ptr, layout);
}
pub fn new<T: 'static + Zeroable>() -> Channel {
Self::new_with_capacity::<T>(10)
}
pub fn new_with_length<T: 'static + Zeroable>(length: usize) -> Channel {
let mut channel = Channel::new_with_capacity::<T>(length);
channel.length = length;
channel
}
pub fn new_with_length_and_type(length: usize, type_id: TypeId) -> Channel {
let mut channel = Channel::new_with_capacity_and_type(length, type_id);
channel.length = length;
channel
}
pub fn new_with_bit_type(length: usize, depth: BitType) -> Channel {
let t_r = match depth {
BitType::U8 => TypeId::of::<u8>(),
BitType::U16 => TypeId::of::<u16>(),
BitType::F32 => TypeId::of::<f32>(),
_ => unimplemented!("Bit-depth :{:?}", depth)
};
Self::new_with_length_and_type(length, t_r)
}
pub fn get_type_id(&self) -> TypeId {
self.type_id
}
pub fn new_with_capacity<T: 'static + Zeroable>(capacity: usize) -> Channel {
Self::new_with_capacity_and_type(capacity, TypeId::of::<T>())
}
pub(crate) fn new_with_capacity_and_type(capacity: usize, type_id: TypeId) -> Channel {
let ptr = unsafe { Self::alloc(capacity) };
Self {
ptr,
length: 0,
capacity,
type_id
}
}
pub fn from_elm<T>(length: usize, elm: T) -> Channel
where
T: Clone + Copy + 'static + Zeroable + Pod
{
let mut new_chan = Channel::new_with_length::<T>(length * size_of::<T>());
new_chan.fill(elm).unwrap();
new_chan
}
fn has_capacity(&self, extra: usize) -> bool {
self.length.saturating_add(extra) <= self.capacity
}
pub fn extend<T: Copy + 'static + Zeroable>(&mut self, data: &[T]) {
assert_eq!(
TypeId::of::<T>(),
self.type_id,
"Type Id's do not match, trying to extend the channel
with a type it wasn't created with"
);
unsafe {
self.extend_unchecked(data);
}
}
unsafe fn extend_unchecked<T: Copy + 'static + Zeroable>(&mut self, data: &[T]) {
let data_size = core::mem::size_of::<T>();
let items = data.len().saturating_mul(data_size);
if !self.has_capacity(items) {
self.realloc(self.capacity.saturating_add(items).saturating_add(10));
}
self.ptr.wrapping_add(self.length).copy_from(
data.as_ptr().cast::<u8>(),
data.len().saturating_mul(data_size)
);
self.length = self.length.checked_add(items).unwrap();
}
pub fn reinterpret_as<T: Default + 'static>(&self) -> Result<&[T], ChannelErrors> {
self.confirm_suspicions::<T>()?;
Ok(unsafe { self.reinterpret_as_unchecked() })
}
unsafe fn reinterpret_as_unchecked<T: Default + 'static>(&self) -> &[T] {
let new_slice = unsafe { std::slice::from_raw_parts_mut::<u8>(self.ptr, self.length) };
let (a, b, c) = new_slice.align_to();
assert!(a.is_empty(), "extra sloppy bytes");
assert!(c.is_empty(), "extra sloppy bytes");
b
}
pub fn reinterpret_as_mut<T: 'static + Pod>(&mut self) -> Result<&mut [T], ChannelErrors> {
self.confirm_suspicions::<T>()?;
let new_slice = unsafe { std::slice::from_raw_parts_mut::<u8>(self.ptr, self.length) };
let (a, b, c) = bytemuck::pod_align_to_mut(new_slice);
assert!(a.is_empty(), "extra sloppy bytes");
assert!(c.is_empty(), "extra sloppy bytes");
Ok(b)
}
pub fn push<T: Copy + 'static + Zeroable>(&mut self, elm: T) {
let size = core::mem::size_of::<T>();
if !self.has_capacity(size) {
unsafe {
self.realloc(self.capacity.saturating_mul(size.saturating_mul(3)) / 2);
}
}
unsafe {
let arr = [elm];
self.ptr
.add(self.length)
.copy_from(arr.as_ptr().cast(), size);
}
self.length += size;
}
pub fn fill<T>(&mut self, element: T) -> Result<(), ChannelErrors>
where
T: Clone + Copy + 'static + Pod
{
let array = self.reinterpret_as_mut()?;
array.fill(element);
Ok(())
}
fn confirm_suspicions<T: 'static>(&self) -> Result<(), ChannelErrors> {
if !is_aligned::<T>(self.ptr) {
return Err(ChannelErrors::UnalignedPointer(
self.ptr as usize,
size_of::<T>()
));
}
if self.length % size_of::<T>() != 0 {
return Err(ChannelErrors::UnevenLength(self.length, size_of::<T>()));
}
let converted_type_id = TypeId::of::<T>();
if converted_type_id != self.type_id {
return Err(ChannelErrors::DifferentType(
self.type_id,
converted_type_id
));
}
Ok(())
}
pub unsafe fn alias(&self) -> &[u8] {
std::slice::from_raw_parts(self.ptr, self.length)
}
pub unsafe fn alias_mut(&mut self) -> &mut [u8] {
std::slice::from_raw_parts_mut(self.ptr, self.length)
}
}
impl Drop for Channel {
fn drop(&mut self) {
unsafe {
self.dealloc();
}
}
}
/// Returns `true` when `ptr` satisfies the alignment requirement of `T`.
///
/// The check uses `align_of::<T>()` rather than the size of `T`: the two
/// differ for types such as `[u8; 3]` (size 3, alignment 1), and a size of
/// zero would make the mask computation underflow.
fn is_aligned<T>(ptr: *const u8) -> bool {
    // Alignments are always non-zero powers of two, so `align - 1` is a
    // valid low-bit mask.
    let mask = core::mem::align_of::<T>() - 1;
    ((ptr as usize) & mask) == 0
}
// Only compiled for test builds, so the `Channel` import is always used and
// no lint suppression is needed.
#[cfg(test)]
mod tests {
    use crate::channel::Channel;

    /// A channel created for `u8` must refuse to be viewed as `u16`.
    #[test]
    fn test_wrong_interpretation() {
        let ch = Channel::new::<u8>();
        assert!(ch.reinterpret_as::<u16>().is_err());
    }

    /// A pushed element is readable back through the matching type.
    #[test]
    fn test_correct_interpretation() {
        let mut ch = Channel::new::<u16>();
        ch.push(70_u16);
        let expected = [70_u16];
        assert_eq!(ch.reinterpret_as::<u16>().unwrap(), expected);
    }

    /// A clone compares equal to the channel it was made from.
    #[test]
    fn test_clone_works() {
        let mut ch = Channel::new::<u8>();
        ch.extend::<u8>(&[10; 10]);
        let ch2 = ch.clone();
        assert_eq!(ch, ch2);
    }
}