use crate::{FileDescriptor, io};
use alloc::boxed::Box;
use async_trait::async_trait;
use core::{
cell::RefCell,
cmp::Ordering,
future::Future,
pin::Pin,
task::{Context, Poll},
};
use kona_preimage::{
Channel,
errors::{ChannelError, ChannelResult},
};
/// A channel backed by a pair of file descriptors: one that is read from and one that is
/// written to.
///
/// The type is `Copy`, so the futures in this file take it by value.
#[derive(Debug, Clone, Copy)]
pub struct FileChannel {
/// Descriptor passed to `io::read` when receiving data.
read_handle: FileDescriptor,
/// Descriptor passed to `io::write` when sending data.
write_handle: FileDescriptor,
}
impl FileChannel {
    /// Builds a [`FileChannel`] that reads from `read_handle` and writes to `write_handle`.
    pub const fn new(read_handle: FileDescriptor, write_handle: FileDescriptor) -> Self {
        Self {
            read_handle,
            write_handle,
        }
    }

    /// Returns a copy of the descriptor this channel reads from.
    pub const fn read_handle(&self) -> FileDescriptor {
        let Self { read_handle, .. } = *self;
        read_handle
    }

    /// Returns a copy of the descriptor this channel writes to.
    pub const fn write_handle(&self) -> FileDescriptor {
        let Self { write_handle, .. } = *self;
        write_handle
    }
}
#[async_trait]
impl Channel for FileChannel {
    /// Issues a single `io::read` on the read handle and returns the number of bytes it
    /// reported. Any I/O error is surfaced as [`ChannelError::Closed`].
    async fn read(&self, buf: &mut [u8]) -> ChannelResult<usize> {
        let outcome = io::read(self.read_handle(), buf);
        outcome.map_err(|_| ChannelError::Closed)
    }

    /// Drives a [`ReadFuture`] over `buf` and returns the total number of bytes read once
    /// the future resolves. Any failure is surfaced as [`ChannelError::Closed`].
    async fn read_exact(&self, buf: &mut [u8]) -> ChannelResult<usize> {
        let fill = ReadFuture::new(*self, buf);
        fill.await.map_err(|_| ChannelError::Closed)
    }

    /// Drives a [`WriteFuture`] over `buf` and returns the number of bytes it reports as
    /// written. Any failure is surfaced as [`ChannelError::Closed`].
    async fn write(&self, buf: &[u8]) -> ChannelResult<usize> {
        let flush = WriteFuture::new(*self, buf);
        flush.await.map_err(|_| ChannelError::Closed)
    }
}
/// Future that repeatedly reads from a [`FileChannel`]'s read handle into a caller-supplied
/// buffer until the buffer has been filled.
struct ReadFuture<'a> {
/// Channel whose read handle is read from.
channel: FileChannel,
/// Destination buffer; mutably borrowed through the `RefCell` on every poll.
buf: RefCell<&'a mut [u8]>,
/// Number of bytes placed into `buf` so far; also the offset of the next read.
read: usize,
}
impl<'a> ReadFuture<'a> {
    /// Creates a future that will fill `buf` from `channel`, starting with zero bytes read.
    // The lint would suggest making this constructor `const`; it is deliberately silenced.
    // NOTE(review): the reason is not visible in this file — confirm whether the allow is
    // still required.
    #[allow(clippy::missing_const_for_fn)]
    fn new(channel: FileChannel, buf: &'a mut [u8]) -> Self {
        let buf = RefCell::new(buf);
        Self { read: 0, channel, buf }
    }
}
impl Future for ReadFuture<'_> {
    type Output = ChannelResult<usize>;

    /// Performs one `io::read` into the still-unfilled tail of the buffer.
    ///
    /// Resolves to:
    /// - `Ok(total)` once the buffer has been completely filled;
    /// - `Err(ChannelError::Closed)` if the read fails, or if it yields zero bytes while the
    ///   buffer still has room (treated as end of stream — without this the future would
    ///   wake itself forever and never resolve).
    ///
    /// Otherwise the task is woken immediately and `Poll::Pending` is returned so the
    /// executor polls again for the remaining bytes.
    fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut buf = self.buf.borrow_mut();
        let buf_len = buf.len();
        let chunk_read = io::read(self.channel.read_handle, &mut buf[self.read..])
            .map_err(|_| ChannelError::Closed)?;

        // Release the `RefCell` borrow before mutating `self.read`.
        drop(buf);

        self.read += chunk_read;

        match self.read.cmp(&buf_len) {
            Ordering::Greater | Ordering::Equal => Poll::Ready(Ok(self.read)),
            // No progress was made although bytes are still outstanding: polling again
            // cannot help, so report the channel as closed instead of spinning.
            Ordering::Less if chunk_read == 0 => Poll::Ready(Err(ChannelError::Closed)),
            Ordering::Less => {
                // More bytes are expected; request another poll right away.
                ctx.waker().wake_by_ref();
                Poll::Pending
            }
        }
    }
}
/// Future that repeatedly writes a caller-supplied buffer to a [`FileChannel`]'s write
/// handle, tracking how much has been written.
struct WriteFuture<'a> {
/// Channel whose write handle is written to.
channel: FileChannel,
/// Source bytes to write.
buf: &'a [u8],
/// Number of bytes of `buf` written so far; also the offset of the next write.
written: usize,
}
impl<'a> WriteFuture<'a> {
    /// Creates a future that will write `buf` to `channel`, starting with zero bytes
    /// written.
    const fn new(channel: FileChannel, buf: &'a [u8]) -> Self {
        Self {
            written: 0,
            buf,
            channel,
        }
    }
}
impl Future for WriteFuture<'_> {
    type Output = ChannelResult<usize>;

    /// Performs one `io::write` of the not-yet-written tail of the buffer.
    ///
    /// Resolves to `Ok(written)` when the whole buffer has been written or when a write
    /// reports zero bytes, and to `Err(ChannelError::Closed)` when the write fails.
    /// Otherwise the task is woken immediately and `Poll::Pending` is returned.
    fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        let remaining = &self.buf[self.written..];
        let accepted = match io::write(self.channel.write_handle(), remaining) {
            Ok(count) => count,
            Err(_) => return Poll::Ready(Err(ChannelError::Closed)),
        };

        self.written += accepted;

        // A zero-byte write ends the operation with whatever has been written so far;
        // so does having written the entire buffer.
        let finished = accepted == 0 || self.written >= self.buf.len();
        if finished {
            return Poll::Ready(Ok(self.written));
        }

        // Bytes remain; request another poll right away.
        ctx.waker().wake_by_ref();
        Poll::Pending
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The read-handle getter returns the descriptor the channel was constructed with.
    #[test]
    fn test_get_read_handle() {
        let expected = FileDescriptor::StdIn;
        let channel = FileChannel::new(expected, FileDescriptor::StdOut);

        assert_eq!(expected, channel.read_handle());
    }
}