use super::{
Errorable, MaybeAlloced, OutputRef, Packet, PacketPerms, Splittable, TransferredPacket,
};
use crate::error::Error;
use cglue::prelude::v1::*;
use core::marker::PhantomData;
use core::mem::{ManuallyDrop, MaybeUninit};
use core::ptr::NonNull;
use core::sync::atomic::*;
use tarc::BaseArc;
/// A [`PacketView`] that has been bound to an optional output handle.
///
/// Completing a bound view (when it is dropped, or through `Errorable::error`)
/// takes the inner view out and, if an output is bound, passes it to the
/// output's vtable `output` callback.
///
/// The `'static` lifetimes are produced by `core::mem::transmute` in
/// `PacketView::bind`. NOTE(review): the real borrow lifetime is erased here
/// and must be upheld by whoever calls `bind`.
#[repr(C)]
pub struct BoundPacketView<Perms: PacketPerms> {
    /// The underlying view. Wrapped in `ManuallyDrop` because it is moved out
    /// explicitly (with `ManuallyDrop::take`) when the view is completed.
    pub(crate) view: ManuallyDrop<PacketView<'static, Perms>>,
    /// Whether `output` is initialized.
    pub(crate) has_output: bool,
    /// The bound output handle; only initialized when `has_output` is `true`.
    pub(crate) output: MaybeUninit<OutputRef<'static, Perms>>,
}
impl<Perms: PacketPerms> Drop for BoundPacketView<Perms> {
    /// Completes the view without an error when it goes out of scope.
    fn drop(&mut self) {
        // SAFETY: a destructor runs at most once, and the paths that complete
        // the view by other means (`split_at`, `error`, `forget`) suppress this
        // destructor via `ManuallyDrop` / `mem::forget`.
        unsafe { Self::output(self, None) }
    }
}
impl<T: PacketPerms> Splittable<u64> for BoundPacketView<T> {
fn split_at(self, len: u64) -> (Self, Self) {
let mut this = ManuallyDrop::new(self);
let view = unsafe { ManuallyDrop::take(&mut this.view) };
let (v1, v2) = view.split_local(len);
let output = if this.has_output {
unsafe {
let output = this.output.assume_init_mut();
output.bound_views.fetch_add(1, Ordering::Release);
MaybeUninit::new(output.clone())
}
} else {
MaybeUninit::uninit()
};
(
Self {
view: ManuallyDrop::new(v1),
has_output: this.has_output,
output,
},
Self {
view: ManuallyDrop::new(v2),
has_output: this.has_output,
output: unsafe { core::ptr::read(&this.output) },
},
)
}
fn len(&self) -> u64 {
self.view.len()
}
}
impl<T: PacketPerms> Errorable for BoundPacketView<T> {
    /// Completes this view with `err`, delivering it to the bound output (if any).
    fn error(self, err: Error) {
        // Suppress our own `Drop`, which would otherwise complete the view a
        // second time (with `None`).
        let mut this = ManuallyDrop::new(self);
        // SAFETY: `this` is never dropped, so `output` runs exactly once. It
        // takes both the view and the stored output handle, so nothing is left
        // to release afterwards; `MaybeUninit` never drops its contents, so no
        // additional cleanup of `this.output` is required (or meaningful).
        unsafe { this.output(Some(err)) }
    }
}
impl<Perms: PacketPerms> BoundPacketView<Perms> {
    /// Releases the view without delivering it to the bound output.
    ///
    /// The packet is notified through `on_output(None)` and the inner view is
    /// dropped, but the stored output handle (if any) is neither passed to its
    /// callback nor dropped — it is leaked together with `self`.
    ///
    /// # Safety
    ///
    /// NOTE(review): the precise contract is defined by callers outside this
    /// file; presumably nothing may be waiting on the bound output for this
    /// view. The `debug_assert!` below expects `on_output` to return no waker.
    pub unsafe fn forget(mut self) {
        let pkt = ManuallyDrop::take(&mut self.view);
        let waker = pkt.pkt().on_output(None);
        debug_assert!(waker.is_none());
        if let Some(waker) = waker {
            waker.wake();
        }
        core::mem::forget(self)
    }

    /// Creates a bound view over `len` bytes starting at `pos` within this view.
    ///
    /// # Safety
    ///
    /// The returned value bitwise-copies this view's output handle without
    /// cloning it or incrementing `bound_views`. NOTE(review): presumably the
    /// caller must ensure the shared handle is completed/released only once
    /// across both views — confirm against call sites. Also inherits the
    /// requirements of `PacketView::extract_packet`.
    pub unsafe fn extract_packet(&self, pos: u64, len: u64) -> Self {
        let b = self.view.extract_packet(pos, len);
        Self {
            view: ManuallyDrop::new(b),
            has_output: self.has_output,
            output: unsafe { core::ptr::read(&self.output) },
        }
    }

    /// Completes this view: notifies the packet and hands the view to the output.
    ///
    /// Takes the view out of `self.view`, reports `error` (paired with the
    /// view's start offset) to the packet's `on_output`, passes the view and
    /// error to the bound output's vtable callback when one is bound, and
    /// finally wakes the waker returned by `on_output`, if any.
    ///
    /// # Safety
    ///
    /// Must be called at most once per `BoundPacketView`. Afterwards both
    /// `view` and `output` are logically moved out and must not be used or
    /// dropped again.
    unsafe fn output(&mut self, error: Option<Error>) {
        let pkt = ManuallyDrop::take(&mut self.view);
        // Read the offset from the taken view: after `ManuallyDrop::take`,
        // `self.view` is logically moved-out and must not be accessed again.
        let start = pkt.start;
        let error = error.map(IntError::into_int_err);
        let waker = pkt.pkt().on_output(error.map(|v| (start, v)));
        if self.has_output {
            let output = self.output.assume_init_read();
            (output.vtbl.output)(output, pkt, error);
        }
        if let Some(waker) = waker {
            waker.wake();
        }
    }

    /// Attempts to allocate via `Perms::try_alloc(self, 1)`.
    ///
    /// See `PacketPerms::try_alloc` for the meaning of the second argument.
    pub fn try_alloc(self) -> MaybeAlloced<Perms> {
        Perms::try_alloc(self, 1)
    }

    /// Returns an unbound view covering the same range as this bound view.
    ///
    /// # Safety
    ///
    /// Inherits the requirements of `PacketView::extract_packet`.
    /// NOTE(review): the returned view is not tied to this bound view's
    /// output — confirm the caller's contract.
    pub unsafe fn unbound(&self) -> PacketView<'static, Perms> {
        self.view.extract_packet(0, self.view.len())
    }

    /// Transfers `input` into the view's data and wraps the result.
    ///
    /// Uses the packet vtable's `transfer_data_fn` when the packet has a
    /// vtable, and `Perms::transfer_data_simple` otherwise.
    ///
    /// # Safety
    ///
    /// NOTE(review): requirements are defined by the transfer functions and
    /// `TransferredPacket`, which are outside this file.
    pub unsafe fn transfer_data(
        mut self,
        input: Perms::ReverseDataType,
    ) -> TransferredPacket<Perms> {
        if let Some(vtable) = self.view.pkt().vtbl.vtbl() {
            (vtable.transfer_data_fn())(&mut self.view, input);
        } else {
            Perms::transfer_data_simple(&mut self.view, input);
        }
        TransferredPacket(self)
    }

    /// Returns a pointer to the first byte of this view's data.
    ///
    /// Returns null when the packet is backed by a vtable; otherwise the
    /// packet's simple data pointer offset by the view's `start`.
    pub fn ptr(&self) -> *const u8 {
        if self.view.pkt().vtbl.vtbl().is_some() {
            core::ptr::null()
        } else {
            unsafe {
                self.view
                    .pkt()
                    .simple_data_ptr()
                    .add(self.view.start as usize)
            }
        }
    }
}
/// A view over the half-open range `start..end` of a [`Packet`].
///
/// Bit 0 of the stored `tag` records whether this view owns a strong
/// `BaseArc` reference to the packet (set by `from_arc_ref`, clear for
/// `from_ref`); `Drop` releases that reference when the bit is set. The
/// remaining bits hold the user tag returned by `tag()`.
#[repr(C)]
pub struct PacketView<'a, Perms: PacketPerms> {
    /// Pointer to the viewed packet.
    pub(crate) pkt: NonNull<Packet<Perms>>,
    /// User tag shifted left by one; bit 0 is the "owns a strong reference" flag.
    pub(crate) tag: u64,
    /// Start offset of the view within the packet (inclusive).
    pub(crate) start: u64,
    /// End offset of the view within the packet (exclusive).
    pub(crate) end: u64,
    /// Ties the view to the lifetime `'a` of a borrowed packet (see `from_ref`).
    phantom: PhantomData<&'a Packet<Perms>>,
}
// SAFETY: NOTE(review): these impls are unconditional in `Perms` and assert that
// a view holding a raw `NonNull<Packet<Perms>>` may be sent to and shared
// between threads. That is only sound if `Packet<Perms>` itself is safe to
// access concurrently for every `Perms` — confirm against `Packet`'s definition.
unsafe impl<Perms: PacketPerms> Send for PacketView<'_, Perms> {}
unsafe impl<Perms: PacketPerms> Sync for PacketView<'_, Perms> {}
impl<Perms: PacketPerms> Drop for PacketView<'_, Perms> {
    /// Releases the strong `BaseArc` reference if this view owns one.
    fn drop(&mut self) {
        // Bit 0 of the stored tag marks views that hold a strong reference.
        let owns_strong_ref = (self.tag & 1) == 1;
        if owns_strong_ref {
            // SAFETY: the flag is only set by constructors that incremented the
            // strong count on behalf of this view.
            unsafe { BaseArc::decrement_strong_count(self.pkt.as_ptr()) }
        }
    }
}
impl<'a, Perms: PacketPerms> PacketView<'a, Perms> {
    /// Creates a view over the whole of a reference-counted packet.
    ///
    /// # Panics
    ///
    /// Panics if the top bit of `tag` is set.
    pub fn from_arc(pkt: BaseArc<Packet<Perms>>, tag: u64) -> Self {
        // `from_arc_ref` takes its own strong reference, so the one passed in
        // by value is released right after.
        let view = Self::from_arc_ref(&pkt, tag);
        core::mem::drop(pkt);
        view
    }

    /// Creates a view over the whole of a reference-counted packet, taking a
    /// new strong reference that the view releases on drop.
    ///
    /// # Panics
    ///
    /// Panics if the top bit of `tag` is set, because it would be shifted out
    /// to make room for the ownership flag.
    pub fn from_arc_ref(pkt: &BaseArc<Packet<Perms>>, tag: u64) -> Self {
        assert!(tag.leading_zeros() > 0);
        let end = Perms::len(pkt);
        unsafe { pkt.on_add_to_view() };
        let raw = NonNull::new(pkt.as_ptr().cast_mut()).unwrap();
        // The view owns this strong reference; bit 0 of the tag records that.
        unsafe { BaseArc::increment_strong_count(raw.as_ptr()) };
        Self {
            pkt: raw,
            tag: (tag << 1) | 1,
            start: 0,
            end,
            phantom: PhantomData,
        }
    }

    /// Binds this view to an optional output handle.
    ///
    /// When an output is given, its `bound_views` counter is incremented.
    ///
    /// # Safety
    ///
    /// Both the view and the output are transmuted to `'static`, erasing the
    /// borrow lifetime `'a`. NOTE(review): presumably the caller must keep the
    /// borrowed data alive until the bound view is completed — confirm.
    pub unsafe fn bind(self, output: Option<OutputRef<'a, Perms>>) -> BoundPacketView<Perms> {
        let (has_output, output) = match output {
            None => (false, MaybeUninit::uninit()),
            Some(output) => {
                output.bound_views.fetch_add(1, Ordering::Release);
                (true, MaybeUninit::new(core::mem::transmute(output)))
            }
        };
        BoundPacketView {
            view: ManuallyDrop::new(core::mem::transmute(self)),
            has_output,
            output,
        }
    }
}
impl<'a, Perms: PacketPerms> PacketView<'a, Perms> {
    /// Creates a view over the whole of a borrowed packet.
    ///
    /// The view does not own a strong reference (bit 0 of the stored tag is
    /// clear), so it is tied to the borrow lifetime `'a` instead.
    ///
    /// # Panics
    ///
    /// Panics if the top bit of `tag` is set, because it would be shifted out
    /// to make room for the ownership flag.
    pub fn from_ref(pkt: &'a Packet<Perms>, tag: u64) -> Self {
        assert!(tag.leading_zeros() > 0);
        let end = Perms::len(pkt);
        unsafe { (*pkt).on_add_to_view() };
        Self {
            pkt: NonNull::from(pkt),
            tag: tag << 1,
            start: 0,
            end,
            phantom: PhantomData,
        }
    }

    /// Shared reference to the viewed packet.
    pub fn pkt(&self) -> &Packet<Perms> {
        unsafe { &*self.pkt.as_ptr().cast_const() }
    }

    /// Mutable reference to the viewed packet.
    ///
    /// NOTE(review): several views of the same packet can exist at once (see
    /// `split_local` / `extract_packet`), and views created by `from_ref` derive
    /// this pointer from a shared reference. Handing out `&mut Packet` here may
    /// therefore alias other references — confirm the intended contract.
    pub fn pkt_mut(&mut self) -> &mut Packet<Perms> {
        unsafe { &mut *self.pkt.as_ptr() }
    }

    /// User tag, without the internal ownership bit.
    pub fn tag(&self) -> u64 {
        self.tag >> 1
    }

    /// Replaces the user tag, preserving the internal ownership bit.
    ///
    /// # Panics
    ///
    /// Panics if the top bit of `tag` is set.
    pub fn set_tag(&mut self, tag: u64) {
        assert!(tag.leading_zeros() > 0);
        self.tag = (tag << 1) | (self.tag & 1);
    }

    /// Length of the view (`end - start`).
    pub fn len(&self) -> u64 {
        self.end - self.start
    }

    /// Start offset of the view within the packet (inclusive).
    pub fn start(&self) -> u64 {
        self.start
    }

    /// End offset of the view within the packet (exclusive).
    pub fn end(&self) -> u64 {
        self.end
    }

    /// Whether the view covers zero bytes.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Splits the view at `pos` (relative to `start`) into
    /// `start..start + pos` and `start + pos..end`.
    ///
    /// The packet's `rc_and_waker` count is incremented and, when this view
    /// owns a strong reference, one additional strong reference is taken, so
    /// each half can be dropped independently.
    ///
    /// # Panics
    ///
    /// Panics if `pos >= self.len()`.
    pub fn split_local(self, pos: u64) -> (Self, Self) {
        assert!(pos < self.len());
        self.pkt().rc_and_waker.inc_rc();
        // All fields are `Copy`, so this copies them out without moving `self`;
        // `self` is forgotten below so its `Drop` does not release anything.
        let Self {
            pkt,
            tag,
            start,
            end,
            phantom,
        } = self;
        if tag & 1 != 0 {
            unsafe {
                BaseArc::increment_strong_count(pkt.as_ptr());
            }
        }
        let ret = (
            Self {
                pkt,
                tag,
                start,
                end: start + pos,
                phantom,
            },
            Self {
                pkt,
                tag,
                start: start + pos,
                end,
                phantom,
            },
        );
        core::mem::forget(self);
        ret
    }

    /// Creates a new view over `offset..offset + len` (relative to `start`)
    /// of the same packet.
    ///
    /// Increments the packet's `rc_and_waker` count and, when this view owns a
    /// strong reference, takes another one for the new view.
    ///
    /// # Panics
    ///
    /// Panics if `offset + len` overflows or exceeds `self.len()`.
    ///
    /// # Safety
    ///
    /// NOTE(review): the new view shares the packet with `self` while `self`
    /// stays alive; the exact caller obligations are defined outside this file.
    pub unsafe fn extract_packet(&self, offset: u64, len: u64) -> Self {
        // Validate before touching any counters so that a failed assertion does
        // not leave the packet's rc permanently incremented. `checked_add`
        // rejects an `offset + len` that overflows `u64`, which would otherwise
        // wrap in release builds and slip past the bound check.
        assert!(offset <= self.len());
        assert!(matches!(offset.checked_add(len), Some(rel_end) if rel_end <= self.len()));
        self.pkt().rc_and_waker.inc_rc();
        if self.tag & 1 != 0 {
            unsafe {
                BaseArc::increment_strong_count(self.pkt.as_ptr());
            }
        }
        // `start + offset + len <= end` by the assertions above, so this cannot overflow.
        let start = self.start + offset;
        Self {
            pkt: self.pkt,
            tag: self.tag,
            start,
            end: start + len,
            phantom: PhantomData,
        }
    }
}