use crate::rt::async_support::waitable::{WaitableOp, WaitableOperation};
use crate::rt::async_support::{AbiBuffer, DROPPED, ReturnCode};
use {
crate::rt::Cleanup,
std::{
alloc::Layout,
fmt,
future::Future,
pin::Pin,
ptr,
sync::atomic::{AtomicU32, Ordering::Relaxed},
task::{Context, Poll},
vec::Vec,
},
};
/// Upper bound on the element count handed to a single `start_write` /
/// `start_read` call; longer buffers are clamped to this length.
/// NOTE(review): `(1 << 28) - 1` presumably mirrors the canonical ABI's
/// per-copy length limit — confirm against the component-model spec.
const MAX_LENGTH: usize = (1 << 28) - 1;
/// Low-level operations for driving one component-model stream type.
///
/// A `&StreamVtable<T>` implements this trait, and the `Raw*` stream types in
/// this module are generic over it.
///
/// NOTE(review): the trait is `unsafe` to implement; implementors presumably
/// must guarantee that `elem_layout`/`lower`/`lift` agree with the memory the
/// `start_*` intrinsics read and write — confirm the exact contract.
#[doc(hidden)]
pub unsafe trait StreamOps: Clone {
/// Element type carried by the stream.
type Payload: 'static;
/// Creates a new stream and returns both handles packed into one `u64`:
/// `raw_stream_new` takes the low 32 bits as the readable end and the high
/// 32 bits as the writable end.
fn new(&mut self) -> u64;
/// Layout of one element in the buffer handed to `start_read`; reads size
/// their temporary buffer with it and step between elements by its size.
fn elem_layout(&self) -> Layout;
/// When `true`, reads land directly in the destination `Vec`'s spare
/// capacity; when `false`, reads go through a temporary buffer and every
/// element is converted with `lift`.
fn native_abi_matches_canonical_abi(&self) -> bool;
/// Whether an element's lowered form contains lists (the `StreamVtable`
/// implementation reports whether it has a `dealloc_lists` function).
fn contains_lists(&self) -> bool;
/// Lowers `payload` into the memory at `dst`.
unsafe fn lower(&mut self, payload: Self::Payload, dst: *mut u8);
/// Presumably frees list allocations owned by the lowered element at `dst`
/// (inferred from the name — confirm).
unsafe fn dealloc_lists(&mut self, dst: *mut u8);
/// Converts the lowered element at `dst` back into a `Payload`.
unsafe fn lift(&mut self, dst: *mut u8) -> Self::Payload;
/// Starts writing `amt` elements from `val` to `stream`; returns a raw status
/// code (see `ReturnCode`).
unsafe fn start_write(&mut self, stream: u32, val: *const u8, amt: usize) -> u32;
/// Starts reading into `amt` element slots at `val` from `stream`; returns a
/// raw status code (see `ReturnCode`).
unsafe fn start_read(&mut self, stream: u32, val: *mut u8, amt: usize) -> u32;
/// Requests cancellation of the in-progress read on `stream`; returns a raw
/// status code.
unsafe fn cancel_read(&mut self, stream: u32) -> u32;
/// Requests cancellation of the in-progress write on `stream`; returns a raw
/// status code.
unsafe fn cancel_write(&mut self, stream: u32) -> u32;
/// Releases the readable end `stream`.
unsafe fn drop_readable(&mut self, stream: u32);
/// Releases the writable end `stream`.
unsafe fn drop_writable(&mut self, stream: u32);
}
/// Table of functions describing one concrete stream element type `T`.
///
/// A reference to it implements `StreamOps` by forwarding to these entries.
#[doc(hidden)]
pub struct StreamVtable<T> {
/// Layout of one element as exchanged with `start_read` / `start_write`.
pub layout: Layout,
/// Lowers a value into memory at `dst`; `None` makes `StreamOps::lower` a no-op.
pub lower: Option<unsafe fn(value: T, dst: *mut u8)>,
/// `Some` means elements contain lists (`contains_lists` returns `true`).
pub dealloc_lists: Option<unsafe fn(dst: *mut u8)>,
/// Lifts an element from memory at `dst`; `None` means the native and
/// canonical ABIs match, so no lifting is needed.
pub lift: Option<unsafe fn(dst: *mut u8) -> T>,
// NOTE(review): the `extern "C"` entries below are presumably the
// component-model `stream.*` built-ins for this element type — confirm.
pub start_write: unsafe extern "C" fn(stream: u32, val: *const u8, amt: usize) -> u32,
pub start_read: unsafe extern "C" fn(stream: u32, val: *mut u8, amt: usize) -> u32,
pub cancel_write: unsafe extern "C" fn(stream: u32) -> u32,
pub cancel_read: unsafe extern "C" fn(stream: u32) -> u32,
pub drop_writable: unsafe extern "C" fn(stream: u32),
pub drop_readable: unsafe extern "C" fn(stream: u32),
/// Creates a stream; the result packs the readable handle in the low 32 bits
/// and the writable handle in the high 32 bits.
pub new: unsafe extern "C" fn() -> u64,
}
/// `StreamOps` backed by a `StreamVtable`: every operation forwards to the
/// matching entry of the table.
unsafe impl<T: 'static> StreamOps for &StreamVtable<T> {
    type Payload = T;

    fn new(&mut self) -> u64 {
        let vtable = *self;
        // SAFETY: the intrinsic takes no arguments; this impl treats stream
        // creation as always callable.
        unsafe { (vtable.new)() }
    }

    fn elem_layout(&self) -> Layout {
        let vtable = *self;
        vtable.layout
    }

    fn native_abi_matches_canonical_abi(&self) -> bool {
        // No lift function means elements can be used exactly as read.
        self.lift.is_none()
    }

    fn contains_lists(&self) -> bool {
        self.dealloc_lists.is_some()
    }

    unsafe fn lower(&mut self, payload: Self::Payload, dst: *mut u8) {
        match self.lower {
            // SAFETY: forwarded from this method's own contract on `dst`.
            Some(lower_fn) => unsafe { lower_fn(payload, dst) },
            // Without a lowering function the payload is simply dropped.
            None => drop(payload),
        }
    }

    unsafe fn dealloc_lists(&mut self, dst: *mut u8) {
        let Some(dealloc_fn) = self.dealloc_lists else {
            return;
        };
        // SAFETY: forwarded from this method's own contract on `dst`.
        unsafe { dealloc_fn(dst) }
    }

    unsafe fn lift(&mut self, dst: *mut u8) -> Self::Payload {
        let lift_fn = self.lift.unwrap();
        // SAFETY: forwarded from this method's own contract on `dst`.
        unsafe { lift_fn(dst) }
    }

    unsafe fn start_write(&mut self, stream: u32, val: *const u8, amt: usize) -> u32 {
        let vtable = *self;
        // SAFETY: forwarded from this method's own contract.
        unsafe { (vtable.start_write)(stream, val, amt) }
    }

    unsafe fn start_read(&mut self, stream: u32, val: *mut u8, amt: usize) -> u32 {
        let vtable = *self;
        // SAFETY: forwarded from this method's own contract.
        unsafe { (vtable.start_read)(stream, val, amt) }
    }

    unsafe fn cancel_read(&mut self, stream: u32) -> u32 {
        let vtable = *self;
        // SAFETY: forwarded from this method's own contract.
        unsafe { (vtable.cancel_read)(stream) }
    }

    unsafe fn cancel_write(&mut self, stream: u32) -> u32 {
        let vtable = *self;
        // SAFETY: forwarded from this method's own contract.
        unsafe { (vtable.cancel_write)(stream) }
    }

    unsafe fn drop_readable(&mut self, stream: u32) {
        let vtable = *self;
        // SAFETY: forwarded from this method's own contract.
        unsafe { (vtable.drop_readable)(stream) }
    }

    unsafe fn drop_writable(&mut self, stream: u32) {
        let vtable = *self;
        // SAFETY: forwarded from this method's own contract.
        unsafe { (vtable.drop_writable)(stream) }
    }
}
/// Creates a new stream whose element operations come from `vtable`.
///
/// # Safety
///
/// Same contract as `raw_stream_new`. NOTE(review): presumably `vtable` must
/// describe this stream's element type and intrinsics — confirm.
pub unsafe fn stream_new<T>(
    vtable: &'static StreamVtable<T>,
) -> (StreamWriter<T>, StreamReader<T>) {
    // SAFETY: the caller's obligations are forwarded unchanged.
    let (writer, reader) = unsafe { raw_stream_new(vtable) };
    (writer, reader)
}
/// Creates a new stream and wraps its two ends with `ops`.
///
/// The writer receives a clone of `ops`; the reader receives `ops` itself.
///
/// # Safety
///
/// NOTE(review): the writable handle is wrapped with the `unsafe`
/// `RawStreamWriter::new`; callers presumably must supply `ops` that correctly
/// describe the stream's element type — confirm the exact contract.
pub unsafe fn raw_stream_new<O>(mut ops: O) -> (RawStreamWriter<O>, RawStreamReader<O>)
where
    O: StreamOps + Clone,
{
    // Both handles come back packed: readable end in the low 32 bits,
    // writable end in the high 32 bits.
    let packed = ops.new();
    let (reader, writer) = (packed as u32, (packed >> 32) as u32);
    rtdebug!("stream.new() = [{writer}, {reader}]");
    let writer_ops = ops.clone();
    // SAFETY: `writer` was just created by `ops.new()` and is handed over to
    // exactly one wrapper.
    let writer = unsafe { RawStreamWriter::new(writer, writer_ops) };
    let reader = RawStreamReader::new(reader, ops);
    (writer, reader)
}
/// Writable end of a stream described by a static `StreamVtable<T>`.
pub type StreamWriter<T> = RawStreamWriter<&'static StreamVtable<T>>;
/// Writable end of a stream, generic over the `StreamOps` that drive it.
///
/// Dropping it calls `drop_writable` on its handle.
pub struct RawStreamWriter<O: StreamOps> {
/// Raw handle of the writable end.
handle: u32,
/// Operations used to write to, cancel, and drop `handle`.
ops: O,
/// Once set (by `StreamWriteOp::in_progress_update`), new writes return
/// `DROPPED` immediately without calling `start_write`.
done: bool,
}
impl<O> RawStreamWriter<O>
where
    O: StreamOps,
{
    /// Wraps the raw writable-end `handle` together with the `ops` driving it.
    ///
    /// # Safety
    ///
    /// NOTE(review): `Drop` passes `handle` to `drop_writable`, so the caller
    /// presumably must own a live writable handle matching `ops` and must not
    /// release it elsewhere — confirm the exact contract.
    #[doc(hidden)]
    pub unsafe fn new(handle: u32, ops: O) -> Self {
        Self {
            handle,
            ops,
            done: false,
        }
    }

    /// Returns the raw handle of this writable end; ownership is kept.
    pub fn handle(&self) -> u32 {
        self.handle
    }

    /// Starts a single write of `values`.
    ///
    /// The returned future resolves to the status and an `AbiBuffer` holding
    /// the values; one write may transfer only part of them.
    pub fn write(&mut self, values: Vec<O::Payload>) -> RawStreamWrite<'_, O> {
        self.write_buf(AbiBuffer::new(values, self.ops.clone()))
    }

    /// Starts a single write of the elements remaining in `values`.
    pub fn write_buf(&mut self, values: AbiBuffer<O>) -> RawStreamWrite<'_, O> {
        RawStreamWrite {
            op: WaitableOperation::new(StreamWriteOp { writer: self }, values),
        }
    }

    /// Writes all of `values`, issuing as many writes as needed.
    ///
    /// Returns the values that were not written. The result is empty unless
    /// the reader was dropped first. NOTE(review): this relies on
    /// `AbiBuffer::into_vec` yielding only the unwritten remainder — confirm.
    pub async fn write_all(&mut self, values: Vec<O::Payload>) -> Vec<O::Payload> {
        let (mut status, mut buf) = self.write(values).await;
        loop {
            match status {
                // The reader is gone; hand back whatever is left.
                StreamResult::Dropped => break,
                // This method never requests cancellation. An unrequested
                // `Cancelled` result is treated like a zero-length completion
                // and the write is retried. Applying this to *every* write,
                // including the first, keeps the final assertion from
                // panicking on it.
                StreamResult::Complete(_) | StreamResult::Cancelled => {
                    if buf.remaining() == 0 {
                        break;
                    }
                }
            }
            (status, buf) = self.write_buf(buf).await;
        }
        assert!(buf.remaining() == 0 || matches!(status, StreamResult::Dropped));
        buf.into_vec()
    }

    /// Writes a single value.
    ///
    /// Returns `Some(value)` if it could not be written (the reader was
    /// dropped), or `None` once it has been written.
    pub async fn write_one(&mut self, value: O::Payload) -> Option<O::Payload> {
        self.write_all(std::vec![value]).await.pop()
    }
}
impl<O> fmt::Debug for RawStreamWriter<O>
where
    O: StreamOps,
{
    /// Formats as `StreamWriter { handle: .. }`; `ops` and `done` are omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("StreamWriter");
        out.field("handle", &self.handle);
        out.finish()
    }
}
impl<O> Drop for RawStreamWriter<O>
where
    O: StreamOps,
{
    /// Releases the writable end via `drop_writable`.
    fn drop(&mut self) {
        let handle = self.handle;
        rtdebug!("stream.drop-writable({})", handle);
        // SAFETY: this writer owns `handle`, and the handle is not used again
        // after the writer is dropped.
        unsafe { self.ops.drop_writable(handle) }
    }
}
/// Write future for a `StreamWriter<T>`.
pub type StreamWrite<'a, T> = RawStreamWrite<'a, &'static StreamVtable<T>>;
/// Future returned by `RawStreamWriter::write` / `write_buf`. It resolves to
/// the write status plus the buffer, advanced past the elements written.
pub struct RawStreamWrite<'a, O: StreamOps> {
/// Waitable operation that drives the write.
op: WaitableOperation<StreamWriteOp<'a, O>>,
}
/// `WaitableOp` state for a single write; mutably borrows the writer for the
/// lifetime of the operation.
struct StreamWriteOp<'a, O: StreamOps> {
writer: &'a mut RawStreamWriter<O>,
}
/// Outcome of a single stream read or write.
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub enum StreamResult {
/// The operation finished after transferring this many elements. A dropped
/// or cancelled operation that still transferred elements is reported this
/// way as well.
Complete(usize),
/// The other end has been dropped and nothing was transferred.
Dropped,
/// The operation was cancelled and nothing was transferred.
Cancelled,
}
unsafe impl<'a, O> WaitableOp for StreamWriteOp<'a, O>
where
    O: StreamOps,
{
    type Start = AbiBuffer<O>;
    type InProgress = AbiBuffer<O>;
    type Result = (StreamResult, AbiBuffer<O>);
    type Cancel = (StreamResult, AbiBuffer<O>);

    /// Issues `start_write` for the buffer's remaining elements, clamped to
    /// `MAX_LENGTH`, and returns the raw status code with the buffer.
    ///
    /// If the reader is already known to be dropped, no host call is made and
    /// `DROPPED` is returned directly.
    fn start(&mut self, buf: Self::Start) -> (u32, Self::InProgress) {
        if self.writer.done {
            return (DROPPED, buf);
        }
        let (ptr, len) = buf.abi_ptr_and_len();
        let amt = len.min(MAX_LENGTH);
        // SAFETY: `ptr`/`amt` cover at most the buffer's remaining elements, and
        // `buf` is returned as the in-progress state. NOTE(review): this relies
        // on `AbiBuffer` keeping that storage at a stable address until the
        // write resolves.
        let code = unsafe { self.writer.ops.start_write(self.writer.handle, ptr, amt) };
        // Log the length actually handed to the host, not the requested one.
        rtdebug!(
            "stream.write({}, {ptr:?}, {amt}) = {code:#x}",
            self.writer.handle
        );
        (code, buf)
    }

    /// A write cancelled before it started transfers nothing.
    fn start_cancelled(&mut self, buf: Self::Start) -> Self::Cancel {
        (StreamResult::Cancelled, buf)
    }

    /// Decodes a status code for the in-flight write.
    ///
    /// `Blocked` keeps the operation pending. Any transferred elements are
    /// consumed from `buf` and reported as `Complete(n)`. Every dropped result
    /// latches `writer.done`, so later writes do not call into the host again.
    fn in_progress_update(
        &mut self,
        mut buf: Self::InProgress,
        code: u32,
    ) -> Result<Self::Result, Self::InProgress> {
        match ReturnCode::decode(code) {
            ReturnCode::Blocked => Err(buf),
            ReturnCode::Dropped(0) => {
                // Previously only `Dropped(n > 0)` latched this, so the next
                // write would re-issue `start_write` on a dead stream.
                self.writer.done = true;
                Ok((StreamResult::Dropped, buf))
            }
            ReturnCode::Cancelled(0) => Ok((StreamResult::Cancelled, buf)),
            code @ (ReturnCode::Completed(amt)
            | ReturnCode::Dropped(amt)
            | ReturnCode::Cancelled(amt)) => {
                let amt = usize::try_from(amt).unwrap();
                buf.advance(amt);
                if let ReturnCode::Dropped(_) = code {
                    self.writer.done = true;
                }
                Ok((StreamResult::Complete(amt), buf))
            }
        }
    }

    /// The writer's handle is what the operation waits on.
    fn in_progress_waitable(&mut self, _: &Self::InProgress) -> u32 {
        self.writer.handle
    }

    /// Requests cancellation of the in-flight write and returns the raw code.
    fn in_progress_cancel(&mut self, _: &mut Self::InProgress) -> u32 {
        // SAFETY: a write is in progress on this handle.
        let code = unsafe { self.writer.ops.cancel_write(self.writer.handle) };
        rtdebug!("stream.cancel-write({}) = {code:#x}", self.writer.handle);
        code
    }

    /// A finished result is returned unchanged as the cancellation result.
    fn result_into_cancel(&mut self, result: Self::Result) -> Self::Cancel {
        result
    }
}
impl<O: StreamOps> Future for RawStreamWrite<'_, O> {
    type Output = (StreamResult, AbiBuffer<O>);

    /// Polls the underlying waitable write operation.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let op = self.pin_project();
        op.poll_complete(cx)
    }
}
impl<'a, O: StreamOps> RawStreamWrite<'a, O> {
    /// Projects the pinned future onto its pinned inner operation.
    fn pin_project(self: Pin<&mut Self>) -> Pin<&mut WaitableOperation<StreamWriteOp<'a, O>>> {
        // SAFETY: `op` is treated as structurally pinned; this type never moves
        // it out of `self`.
        unsafe { self.map_unchecked_mut(|this| &mut this.op) }
    }

    /// Cancels the write through `WaitableOperation::cancel`, returning the
    /// status and the buffer.
    pub fn cancel(self: Pin<&mut Self>) -> (StreamResult, AbiBuffer<O>) {
        let op = self.pin_project();
        op.cancel()
    }
}
/// Readable end of a stream described by a static `StreamVtable<T>`.
pub type StreamReader<T> = RawStreamReader<&'static StreamVtable<T>>;
/// Readable end of a stream, generic over the `StreamOps` that drive it.
///
/// Dropping it calls `drop_readable`, unless the handle was moved out with
/// `take_handle`.
pub struct RawStreamReader<O: StreamOps> {
/// Raw handle of the readable end; `u32::MAX` means the handle has been
/// taken and this reader no longer owns it.
handle: AtomicU32,
/// Operations used to read from, cancel, and drop the handle.
ops: O,
/// Once set (by `StreamReadOp::in_progress_update`), new reads return
/// `DROPPED` immediately without calling `start_read`.
done: bool,
}
impl<O: StreamOps> fmt::Debug for RawStreamReader<O> {
    /// Formats as `StreamReader { handle: .. }`; `ops` and `done` are omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("StreamReader");
        out.field("handle", &self.handle);
        out.finish()
    }
}
impl<O: StreamOps> RawStreamReader<O> {
    /// Wraps the raw readable-end `handle` together with the `ops` driving it.
    ///
    /// `u32::MAX` is reserved as the "handle taken" sentinel.
    #[doc(hidden)]
    pub fn new(handle: u32, ops: O) -> Self {
        Self {
            handle: AtomicU32::new(handle),
            ops,
            done: false,
        }
    }

    /// Moves the raw handle out of this reader.
    ///
    /// Afterwards the reader holds the `u32::MAX` sentinel. Dropping it no
    /// longer calls `drop_readable`, and `handle()` panics.
    ///
    /// # Panics
    ///
    /// Panics if the handle has already been taken.
    #[doc(hidden)]
    pub fn take_handle(&self) -> u32 {
        // Read and clear in one atomic step. A separate load + store would let
        // two callers sharing `&self` both receive the same handle.
        Self::decode_handle(self.handle.swap(u32::MAX, Relaxed)).unwrap()
    }

    /// Returns the raw handle of this readable end.
    ///
    /// # Panics
    ///
    /// Panics if the handle has been moved out with `take_handle`.
    pub fn handle(&self) -> u32 {
        self.opt_handle().unwrap()
    }

    /// Returns the handle, or `None` if it has been taken.
    fn opt_handle(&self) -> Option<u32> {
        Self::decode_handle(self.handle.load(Relaxed))
    }

    /// Interprets a stored value, treating the `u32::MAX` sentinel as "taken".
    fn decode_handle(raw: u32) -> Option<u32> {
        match raw {
            u32::MAX => None,
            other => Some(other),
        }
    }

    /// Starts a single read into the spare capacity of `buf`.
    ///
    /// Existing elements are kept. At most `buf.capacity() - buf.len()`
    /// elements are requested, further clamped to `MAX_LENGTH`. The future
    /// resolves to the status and `buf` with any received elements appended.
    pub fn read(&mut self, buf: Vec<O::Payload>) -> RawStreamRead<'_, O> {
        RawStreamRead {
            op: WaitableOperation::new(StreamReadOp { reader: self }, buf),
        }
    }

    /// Reads a single element. Returns `None` if the read produced nothing,
    /// for example because the writer was dropped.
    pub async fn next(&mut self) -> Option<O::Payload> {
        let (_result, mut buf) = self.read(Vec::with_capacity(1)).await;
        buf.pop()
    }

    /// Reads until the writer is dropped and returns every element received.
    ///
    /// # Panics
    ///
    /// Panics (via `unreachable!`) if a read reports `Cancelled`.
    pub async fn collect(mut self) -> Vec<O::Payload> {
        let mut ret = Vec::new();
        loop {
            // Make sure every read has at least one free slot to fill.
            if ret.len() == ret.capacity() {
                ret.reserve(1);
            }
            let (status, buf) = self.read(ret).await;
            ret = buf;
            match status {
                StreamResult::Complete(_) => {}
                StreamResult::Dropped => break,
                StreamResult::Cancelled => unreachable!(),
            }
        }
        ret
    }
}
impl<O: StreamOps> Drop for RawStreamReader<O> {
    /// Releases the readable end via `drop_readable`, unless it was taken.
    fn drop(&mut self) {
        if let Some(handle) = self.opt_handle() {
            rtdebug!("stream.drop-readable({})", handle);
            // SAFETY: the reader still owns `handle` (it was not taken), and
            // the handle is not used again after the reader is dropped.
            unsafe { self.ops.drop_readable(handle) }
        }
    }
}
/// Read future for a `StreamReader<T>`.
pub type StreamRead<'a, T> = RawStreamRead<'a, &'static StreamVtable<T>>;
/// Future returned by `RawStreamReader::read`. It resolves to the read status
/// plus the caller's `Vec` with any received elements appended.
pub struct RawStreamRead<'a, O: StreamOps> {
/// Waitable operation that drives the read.
op: WaitableOperation<StreamReadOp<'a, O>>,
}
/// `WaitableOp` state for a single read; mutably borrows the reader for the
/// lifetime of the operation.
struct StreamReadOp<'a, O: StreamOps> {
reader: &'a mut RawStreamReader<O>,
}
unsafe impl<'a, O: StreamOps> WaitableOp for StreamReadOp<'a, O> {
    type Start = Vec<O::Payload>;
    type InProgress = (Vec<O::Payload>, Option<Cleanup>);
    type Result = (StreamResult, Vec<O::Payload>);
    type Cancel = (StreamResult, Vec<O::Payload>);

    /// Issues `start_read` for up to `MAX_LENGTH` of `buf`'s spare slots.
    ///
    /// If the native and canonical ABIs match, the host writes straight into
    /// `buf`'s spare capacity. Otherwise it writes into a temporary allocation
    /// (owned by the returned `Cleanup`) sized for exactly the requested count.
    /// If the writer is already known to be dropped, no host call is made.
    fn start(&mut self, mut buf: Self::Start) -> (u32, Self::InProgress) {
        if self.reader.done {
            return (DROPPED, (buf, None));
        }
        let cap = buf.spare_capacity_mut();
        // Never request more slots than a single read may transfer.
        let amt = cap.len().min(MAX_LENGTH);
        let ptr;
        let cleanup;
        if self.reader.ops.native_abi_matches_canonical_abi() {
            ptr = cap.as_mut_ptr().cast();
            cleanup = None;
        } else {
            let elem_layout = self.reader.ops.elem_layout();
            // A wrapped multiplication would yield an undersized buffer that the
            // host then writes past, so overflow must panic instead. Sizing by
            // `amt` rather than the full spare capacity also avoids allocating
            // slots the host can never fill.
            let size = elem_layout.size().checked_mul(amt).unwrap();
            let layout = Layout::from_size_align(size, elem_layout.align()).unwrap();
            (ptr, cleanup) = Cleanup::new(layout);
        }
        // SAFETY: `ptr` addresses room for `amt` elements: either `buf`'s
        // spare capacity, or the temporary allocation kept alive by `cleanup`
        // in the in-progress state.
        let code = unsafe { self.reader.ops.start_read(self.reader.handle(), ptr, amt) };
        rtdebug!(
            "stream.read({}, {ptr:?}, {}) = {code:#x}",
            self.reader.handle(),
            amt
        );
        (code, (buf, cleanup))
    }

    /// A read cancelled before it started transfers nothing.
    fn start_cancelled(&mut self, buf: Self::Start) -> Self::Cancel {
        (StreamResult::Cancelled, buf)
    }

    /// Decodes a status code for the in-flight read.
    ///
    /// `Blocked` keeps the operation pending. Received elements are appended
    /// to `buf`, either by extending its length or by lifting them out of the
    /// temporary buffer, and are reported as `Complete(n)`. Every dropped
    /// result latches `reader.done`.
    fn in_progress_update(
        &mut self,
        (mut buf, cleanup): Self::InProgress,
        code: u32,
    ) -> Result<Self::Result, Self::InProgress> {
        match ReturnCode::decode(code) {
            ReturnCode::Blocked => Err((buf, cleanup)),
            ReturnCode::Dropped(0) => {
                // Previously only `Dropped(n > 0)` latched this, so the next
                // read would re-issue `start_read` on a dead stream.
                self.reader.done = true;
                Ok((StreamResult::Dropped, buf))
            }
            ReturnCode::Cancelled(0) => Ok((StreamResult::Cancelled, buf)),
            code @ (ReturnCode::Completed(amt)
            | ReturnCode::Dropped(amt)
            | ReturnCode::Cancelled(amt)) => {
                let amt = usize::try_from(amt).unwrap();
                let cur_len = buf.len();
                assert!(amt <= buf.capacity() - cur_len);
                if self.reader.ops.native_abi_matches_canonical_abi() {
                    // SAFETY: the host initialized `amt` elements directly in
                    // the spare capacity, and the assertion above bounds `amt`.
                    unsafe {
                        buf.set_len(cur_len + amt);
                    }
                } else {
                    let elem_size = self.reader.ops.elem_layout().size();
                    let mut ptr = cleanup
                        .as_ref()
                        .map(|c| c.ptr.as_ptr())
                        .unwrap_or(ptr::null_mut());
                    for _ in 0..amt {
                        // SAFETY: the host wrote `amt` lowered elements,
                        // `elem_size` apart, into the temporary buffer. `push`
                        // cannot reallocate because capacity was asserted above.
                        unsafe {
                            buf.push(self.reader.ops.lift(ptr));
                            ptr = ptr.add(elem_size);
                        }
                    }
                }
                drop(cleanup);
                if let ReturnCode::Dropped(_) = code {
                    self.reader.done = true;
                }
                Ok((StreamResult::Complete(amt), buf))
            }
        }
    }

    /// The reader's handle is what the operation waits on.
    fn in_progress_waitable(&mut self, _: &Self::InProgress) -> u32 {
        self.reader.handle()
    }

    /// Requests cancellation of the in-flight read and returns the raw code.
    fn in_progress_cancel(&mut self, _: &mut Self::InProgress) -> u32 {
        // SAFETY: a read is in progress on this handle.
        let code = unsafe { self.reader.ops.cancel_read(self.reader.handle()) };
        rtdebug!("stream.cancel-read({}) = {code:#x}", self.reader.handle());
        code
    }

    /// A finished result is returned unchanged as the cancellation result.
    fn result_into_cancel(&mut self, result: Self::Result) -> Self::Cancel {
        result
    }
}
impl<O: StreamOps> Future for RawStreamRead<'_, O> {
    type Output = (StreamResult, Vec<O::Payload>);

    /// Polls the underlying waitable read operation.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let op = self.pin_project();
        op.poll_complete(cx)
    }
}
impl<'a, O> RawStreamRead<'a, O>
where
    O: StreamOps,
{
    /// Projects the pinned future onto its pinned inner operation.
    fn pin_project(self: Pin<&mut Self>) -> Pin<&mut WaitableOperation<StreamReadOp<'a, O>>> {
        // SAFETY: `op` is treated as structurally pinned; this type never moves
        // it out of `self`.
        unsafe { self.map_unchecked_mut(|this| &mut this.op) }
    }

    /// Cancels the read through `WaitableOperation::cancel`, returning the
    /// status and the read buffer.
    pub fn cancel(self: Pin<&mut Self>) -> (StreamResult, Vec<O::Payload>) {
        let op = self.pin_project();
        op.cancel()
    }
}