use crate::io::BufferedFile;
use crate::sys::Source;
use crate::Reactor;
use futures_lite::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, SeekFrom};
use pin_project_lite::pin_project;
use std::convert::TryInto;
use std::io;
use std::os::unix::io::AsRawFd;
use std::pin::Pin;
use std::task::{Context, Poll, Waker};
pin_project! {
/// Buffered, asynchronous reader over a [`BufferedFile`], optionally limited to a byte range.
///
/// Created through [`StreamReaderBuilder`].
#[derive(Debug)]
pub struct StreamReader {
// Underlying file the read requests are issued against.
file: BufferedFile,
// File offset of the next read request. Filling the buffer advances it by the
// bytes installed, so it runs ahead of the caller by the unread bytes in `buffer`.
file_pos: u64,
// Exclusive upper bound: bytes at or past this offset are never handed out.
max_pos: u64,
// In-flight reactor request (a read from `poll_fill_buf`, or a statx issued by a
// `SeekFrom::End` seek), if any.
source: Option<Source>,
// Bytes read from the file and not yet handed out.
buffer: Buffer,
}
}
pin_project! {
/// Asynchronous, buffered reader for the process's standard input; obtain one with [`stdin`].
#[derive(Debug)]
pub struct Stdin {
// In-flight read of `STDIN_FILENO`, if any.
source: Option<Source>,
// Bytes read from standard input and not yet handed out.
buffer: Buffer,
}
}
/// Returns an asynchronous, buffered reader for the process's standard input.
///
/// Every call builds a fresh reader with its own 128-byte buffer and no read in flight.
pub fn stdin() -> Stdin {
    let buffer = Buffer::new(128);
    Stdin {
        buffer,
        source: None,
    }
}
/// Buffered, asynchronous writer over a [`BufferedFile`].
///
/// Created through [`StreamWriterBuilder`]. Call the inherent `close` method to
/// write out the remaining buffered bytes (and, unless disabled, `fdatasync` them)
/// before the file is closed.
#[derive(Debug)]
pub struct StreamWriter {
/// Underlying file.
file: BufferedFile,
/// File offset at which the currently buffered bytes will be written.
file_pos: u64,
/// Whether `close` calls `fdatasync` before closing the file.
sync_on_close: bool,
/// In-flight reactor request (a buffer flush, or a statx issued by a
/// `SeekFrom::End` seek), if any.
source: Option<Source>,
/// Bytes accepted by `poll_write` but not yet handed to the reactor.
buffer: Buffer,
}
/// Configures and builds a [`StreamReader`].
#[derive(Debug)]
pub struct StreamReaderBuilder {
/// First file offset to read (default 0).
start: u64,
/// Exclusive end offset (default `u64::MAX`, i.e. effectively unbounded).
end: u64,
/// Size of the reader's buffer and upper bound on each read request (default 4 KiB).
buffer_size: usize,
/// File the reader will read from.
file: BufferedFile,
}
/// Configures and builds a [`StreamWriter`].
#[derive(Debug)]
pub struct StreamWriterBuilder {
/// Maximum number of bytes the writer's buffer holds (default 128 KiB).
buffer_size: usize,
/// Whether `StreamWriter::close` calls `fdatasync` (default `true`).
sync_on_close: bool,
/// File the writer will write to.
file: BufferedFile,
}
/// Byte buffer shared by the stream types.
///
/// Readers install freshly read data with `replace_buffer`, hand out
/// `unconsumed_bytes` and advance the cursor with `consume`. The writer appends
/// with `copy_from_buffer` (never beyond `max_buffer_size` bytes) and moves the
/// contents out with `consumed_bytes`.
#[derive(Debug)]
struct Buffer {
    /// Cap on the bytes `copy_from_buffer` will hold; also the initial capacity.
    max_buffer_size: usize,
    /// Cursor into `data`: everything before it has been consumed.
    buffer_pos: usize,
    /// Buffered bytes.
    data: Vec<u8>,
}

impl Buffer {
    /// Creates an empty buffer with `max_buffer_size` bytes preallocated.
    fn new(max_buffer_size: usize) -> Buffer {
        let data = Vec::with_capacity(max_buffer_size);
        Buffer {
            max_buffer_size,
            buffer_pos: 0,
            data,
        }
    }

    /// Installs `buf` as the contents and rewinds the cursor to its start.
    fn replace_buffer(&mut self, buf: Vec<u8>) {
        self.data = buf;
        self.buffer_pos = 0;
    }

    /// Number of bytes between the cursor and the end of the data.
    fn remaining_unconsumed_bytes(&self) -> usize {
        let total = self.data.len();
        total - self.buffer_pos
    }

    /// Advances the cursor by `amt` bytes.
    fn consume(&mut self, amt: usize) {
        self.buffer_pos += amt;
    }

    /// Appends as much of `buf` as still fits under `max_buffer_size`; returns how many bytes were taken.
    fn copy_from_buffer(&mut self, buf: &[u8]) -> usize {
        let room = self.max_buffer_size - self.data.len();
        let taken = &buf[..std::cmp::min(room, buf.len())];
        self.data.extend_from_slice(taken);
        self.buffer_pos += taken.len();
        taken.len()
    }

    /// Moves all data out (consumed or not), leaving an empty, unallocated vector.
    ///
    /// The cursor is left untouched; callers reset it via `replace_buffer`.
    fn consumed_bytes(&mut self) -> Vec<u8> {
        std::mem::take(&mut self.data)
    }

    /// The bytes after the cursor.
    fn unconsumed_bytes(&self) -> &[u8] {
        &self.data[self.buffer_pos..]
    }
}
impl StreamReader {
pub async fn close(self) -> io::Result<()> {
self.file.close().await
}
fn new(builder: StreamReaderBuilder) -> StreamReader {
StreamReader {
file: builder.file,
file_pos: builder.start,
max_pos: builder.end,
source: None,
buffer: Buffer::new(builder.buffer_size),
}
}
}
impl StreamReaderBuilder {
    /// Starts a builder that reads `file` from offset 0 with no end bound and a 4 KiB buffer.
    #[must_use = "The builder must be built to be useful"]
    pub fn new(file: BufferedFile) -> StreamReaderBuilder {
        StreamReaderBuilder {
            file,
            start: 0,
            end: u64::MAX,
            buffer_size: 4 * 1024,
        }
    }

    /// Sets the first file offset to read from.
    pub fn with_start_pos(self, start: u64) -> Self {
        StreamReaderBuilder { start, ..self }
    }

    /// Sets the exclusive end offset; bytes at or past `end` are never returned.
    pub fn with_end_pos(self, end: u64) -> Self {
        StreamReaderBuilder { end, ..self }
    }

    /// Sets the buffer size in bytes; 0 is raised to 1.
    pub fn with_buffer_size(self, buffer_size: usize) -> Self {
        StreamReaderBuilder {
            buffer_size: buffer_size.max(1),
            ..self
        }
    }

    /// Consumes the builder and returns the configured reader.
    pub fn build(self) -> StreamReader {
        StreamReader::new(self)
    }
}
impl StreamWriterBuilder {
    /// Starts a builder with a 128 KiB buffer that syncs the file when the writer is closed.
    #[must_use = "The builder must be built to be useful"]
    pub fn new(file: BufferedFile) -> StreamWriterBuilder {
        StreamWriterBuilder {
            file,
            sync_on_close: true,
            buffer_size: 128 * 1024,
        }
    }

    /// Passing `true` skips the `fdatasync` that `StreamWriter::close` performs by default.
    pub fn with_sync_on_close_disabled(self, disabled: bool) -> Self {
        StreamWriterBuilder {
            sync_on_close: !disabled,
            ..self
        }
    }

    /// Sets the write buffer size in bytes; 0 is raised to 1.
    pub fn with_buffer_size(self, buffer_size: usize) -> Self {
        StreamWriterBuilder {
            buffer_size: buffer_size.max(1),
            ..self
        }
    }

    /// Consumes the builder and returns the configured writer.
    pub fn build(self) -> StreamWriter {
        StreamWriter::new(self)
    }
}
impl StreamWriter {
    /// Writes out whatever is still buffered, optionally syncs, and closes the file.
    ///
    /// `fdatasync` runs unless the builder was configured with
    /// `with_sync_on_close_disabled(true)`.
    ///
    /// # Errors
    /// Returns the first error from the final write, the sync or the close; the
    /// remaining steps are skipped once one fails.
    pub async fn close(mut self) -> io::Result<()> {
        // NOTE(review): a flush still in flight in `self.source` is not awaited here,
        // and a short `write_at` is not retried; confirm neither can happen at this point.
        let bytes = self.buffer.consumed_bytes();
        if !bytes.is_empty() {
            let sz = self.file.write_at(bytes, self.file_pos).await?;
            self.file_pos += sz as u64;
        }
        if self.sync_on_close {
            self.file.fdatasync().await?;
        }
        self.file.close().await
    }

    /// Builds a writer that starts writing at offset 0.
    fn new(builder: StreamWriterBuilder) -> StreamWriter {
        StreamWriter {
            file: builder.file,
            sync_on_close: builder.sync_on_close,
            file_pos: 0,
            source: None,
            buffer: Buffer::new(builder.buffer_size),
        }
    }

    /// Settles a completed flush submitted by `flush_write_buffer`.
    ///
    /// Advances `file_pos` by the bytes the write actually reported and puts the
    /// buffer back for reuse. If the write was short, the unwritten tail stays at
    /// the front of the buffer so the next flush (or `close`) writes it, rather
    /// than being skipped over and silently lost.
    ///
    /// # Errors
    /// Propagates the write error (the buffered bytes are then dropped), and
    /// returns `WriteZero` if the write made no progress at all.
    fn consume_flush_result(&mut self, mut source: Source) -> io::Result<()> {
        let written = source.take_result().unwrap()?;
        let mut buffer = source.extract_buffer();
        if written == 0 && !buffer.is_empty() {
            // Retrying would spin without progress; mirror std's `write_all` behaviour.
            return Err(io::Error::new(
                io::ErrorKind::WriteZero,
                "failed to write buffered data",
            ));
        }
        let written = std::cmp::min(written, buffer.len());
        self.file_pos += written as u64;
        buffer.drain(..written);
        self.buffer.replace_buffer(buffer);
        Ok(())
    }

    /// Hands every buffered byte to the reactor as one write at `file_pos`.
    ///
    /// Returns `true` if a write was submitted (its `Source` is parked in
    /// `self.source` with `waker` registered on it), `false` if the buffer was empty.
    ///
    /// # Panics
    /// Panics if another request is already in flight.
    fn flush_write_buffer(&mut self, waker: Waker) -> bool {
        assert!(self.source.is_none());
        let bytes = self.buffer.consumed_bytes();
        if bytes.is_empty() {
            return false;
        }
        let source = Reactor::get().write_buffered(self.file.as_raw_fd(), bytes, self.file_pos);
        source.add_waiter(waker);
        self.source = Some(source);
        true
    }
}
/// Resolves `base + delta` to an absolute file offset.
///
/// Returns an `InvalidInput` error, instead of silently wrapping, when the
/// target would be negative or would not fit in a `u64`; `std::io::Seek`
/// likewise treats seeking before byte 0 as an error.
fn seek_target(base: u64, delta: i64) -> io::Result<u64> {
    let target = i128::from(base) + i128::from(delta);
    target.try_into().map_err(|_| {
        io::Error::new(
            io::ErrorKind::InvalidInput,
            "invalid seek to a negative or overflowing position",
        )
    })
}

/// Shared `poll_seek` body for the stream reader and writer.
///
/// `$self` must expose `file_pos`, `file` and `source`. `Start` and `Current`
/// complete immediately. `End` needs the file size: the first poll submits a
/// `statx` request through the reactor, parks it in `source` and returns
/// `Pending`; the following poll takes it back and reads `stx_size`. A target
/// outside `0..=u64::MAX` yields `InvalidInput` and leaves `file_pos` unchanged.
macro_rules! do_poll {
    ( $self:expr, $cx:expr, $pos:expr ) => {
        match $pos {
            SeekFrom::Start(pos) => {
                $self.file_pos = pos;
                Poll::Ready(Ok(pos))
            }
            SeekFrom::Current(delta) => match seek_target($self.file_pos, delta) {
                Ok(target) => {
                    $self.file_pos = target;
                    Poll::Ready(Ok(target))
                }
                Err(e) => Poll::Ready(Err(e)),
            },
            SeekFrom::End(delta) => match $self.source.take() {
                None => {
                    let source = Reactor::get().statx($self.file.as_raw_fd(), $self.file.path());
                    source.add_waiter($cx.waker().clone());
                    $self.source = Some(source);
                    Poll::Pending
                }
                Some(mut source) => match source.take_result() {
                    // Report a failed statx instead of panicking on the conversion below.
                    Some(Err(e)) => Poll::Ready(Err(e)),
                    _ => {
                        let stype = source.extract_source_type();
                        let stat_buf: libc::statx = stype.try_into().unwrap();
                        match seek_target(stat_buf.stx_size, delta) {
                            Ok(target) => {
                                $self.file_pos = target;
                                Poll::Ready(Ok(target))
                            }
                            Err(e) => Poll::Ready(Err(e)),
                        }
                    }
                },
            },
        }
    };
}
impl AsyncSeek for StreamReader {
    /// Repositions the reader; the next read starts at the new offset.
    ///
    /// `file_pos` runs ahead of the caller by the bytes still unread in the
    /// buffer. Those bytes are first given back (rewinding `file_pos` to the
    /// caller's logical position) and discarded. As a result,
    /// `SeekFrom::Current` is measured from what the caller has actually
    /// consumed, and stale bytes from the old position are never served after
    /// the seek.
    fn poll_seek(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        pos: SeekFrom,
    ) -> Poll<io::Result<u64>> {
        let unread = self.buffer.remaining_unconsumed_bytes();
        if unread > 0 {
            // Invariant: the buffer was filled from [file_pos - len, file_pos), so this cannot underflow.
            self.file_pos -= unread as u64;
            self.buffer.consume(unread);
        }
        do_poll!(self, cx, pos)
    }
}
impl AsyncSeek for StreamWriter {
fn poll_seek(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
pos: SeekFrom,
) -> Poll<io::Result<u64>> {
do_poll!(self, cx, pos)
}
}
impl AsyncRead for StreamReader {
    /// Copies buffered bytes into `buf`, refilling the buffer first if it is empty.
    ///
    /// Returns `Ok(0)` when `buf` is empty or when the refilled buffer comes back
    /// empty (end of the file or of the configured range).
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let available = futures_lite::ready!(self.as_mut().poll_fill_buf(cx))?;
        let n = std::cmp::min(available.len(), buf.len());
        let (dst, _) = buf.split_at_mut(n);
        dst.copy_from_slice(&available[..n]);
        self.consume(n);
        Poll::Ready(Ok(n))
    }
}
impl AsyncBufRead for StreamReader {
    /// Returns the unread part of the buffer, refilling it from the file when empty.
    ///
    /// A refill submits one read at `file_pos` of at most `max_buffer_size`
    /// bytes, capped so it never extends past `max_pos`, and returns `Pending`.
    /// The next poll installs the bytes that were read, trimmed to `max_pos`.
    /// Once `file_pos` has reached `max_pos` (including after a seek beyond it),
    /// an empty slice is returned without any I/O, which signals end of stream.
    fn poll_fill_buf<'a>(
        mut self: Pin<&'a mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<io::Result<&'a [u8]>> {
        if let Some(mut completed) = self.source.take() {
            let nread = completed.take_result().unwrap()?;
            let mut data = completed.extract_buffer();
            // saturating_sub: file_pos may already be past max_pos after a seek.
            let window = self.max_pos.saturating_sub(self.file_pos);
            let accepted = std::cmp::min(nread as u64, window);
            self.file_pos += accepted;
            data.truncate(accepted as usize);
            self.buffer.replace_buffer(data);
        } else if self.buffer.remaining_unconsumed_bytes() == 0 && self.file_pos < self.max_pos {
            let window = self.max_pos - self.file_pos;
            let len = std::cmp::min(self.buffer.max_buffer_size as u64, window) as usize;
            let fd = self.file.as_raw_fd();
            let request = Reactor::get().read_buffered(fd, self.file_pos, len);
            request.add_waiter(cx.waker().clone());
            self.source = Some(request);
            return Poll::Pending;
        }
        let this = self.project();
        Poll::Ready(Ok(this.buffer.unconsumed_bytes()))
    }

    /// Marks `amt` buffered bytes as handed out.
    fn consume(mut self: Pin<&mut Self>, amt: usize) {
        self.buffer.consume(amt);
    }
}
impl AsyncWrite for StreamWriter {
    /// Accepts as many bytes of `buf` as fit into the write buffer.
    ///
    /// Bytes accumulate until the buffer holds `max_buffer_size` of them. Only
    /// then is the buffer handed to the reactor as a single write at `file_pos`,
    /// and this call returns `Pending`. The next poll settles that write and
    /// copies from `buf`.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        if let Some(source) = self.source.take() {
            self.consume_flush_result(source)?;
        }
        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }
        // Flush only a full buffer; flushing whenever any byte is pending would
        // turn every write into its own I/O and defeat the buffering.
        if self.buffer.data.len() < self.buffer.max_buffer_size {
            return Poll::Ready(Ok(self.buffer.copy_from_buffer(buf)));
        }
        // max_buffer_size >= 1 (enforced by the builder), so a full buffer is non-empty.
        let submitted = self.flush_write_buffer(cx.waker().clone());
        assert!(submitted);
        Poll::Pending
    }

    /// Writes out everything buffered; resolves once no bytes remain in memory.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        if let Some(source) = self.source.take() {
            self.consume_flush_result(source)?;
        }
        // Submit whatever is (still) buffered, e.g. bytes left behind by a short write.
        if self.flush_write_buffer(cx.waker().clone()) {
            Poll::Pending
        } else {
            Poll::Ready(Ok(()))
        }
    }

    /// Flushes buffered bytes so generic `AsyncWrite` users do not lose data.
    ///
    /// This neither syncs nor closes the file; use the inherent
    /// `StreamWriter::close` for that.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        self.poll_flush(cx)
    }
}
impl AsyncRead for Stdin {
    /// Copies buffered standard-input bytes into `buf`, refilling the buffer first if it is empty.
    ///
    /// Returns `Ok(0)` when `buf` is empty or when the refill produced no bytes.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let available = futures_lite::ready!(self.as_mut().poll_fill_buf(cx))?;
        let n = std::cmp::min(available.len(), buf.len());
        let (dst, _) = buf.split_at_mut(n);
        dst.copy_from_slice(&available[..n]);
        self.consume(n);
        Poll::Ready(Ok(n))
    }
}
impl AsyncBufRead for Stdin {
    /// Returns the unread part of the buffer, refilling it from standard input when empty.
    ///
    /// A refill submits a read of up to `max_buffer_size` bytes from
    /// `STDIN_FILENO` (with offset argument 0) and returns `Pending`. The next
    /// poll installs whatever was read. A zero-byte result yields an empty slice.
    fn poll_fill_buf<'a>(
        mut self: Pin<&'a mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<io::Result<&'a [u8]>> {
        if let Some(mut completed) = self.source.take() {
            let nread = completed.take_result().unwrap()?;
            let mut chunk = completed.extract_buffer();
            chunk.truncate(nread);
            self.buffer.replace_buffer(chunk);
        } else if self.buffer.remaining_unconsumed_bytes() == 0 {
            let request = Reactor::get().read_buffered(
                libc::STDIN_FILENO,
                0,
                self.buffer.max_buffer_size,
            );
            request.add_waiter(cx.waker().clone());
            self.source = Some(request);
            return Poll::Pending;
        }
        let this = self.project();
        Poll::Ready(Ok(this.buffer.unconsumed_bytes()))
    }

    /// Marks `amt` buffered bytes as handed out.
    fn consume(mut self: Pin<&mut Self>, amt: usize) {
        self.buffer.consume(amt);
    }
}
// Tests for the stream reader and writer. Every test body runs once for each
// directory returned by `make_test_directories`.
#[cfg(test)]
mod test {
use super::*;
use crate::io::dma_file::test::make_test_directories;
use futures_lite::{AsyncBufReadExt, AsyncReadExt, AsyncSeekExt, AsyncWriteExt, StreamExt};
use std::io::ErrorKind;
// Declares a #[test] that, for each test directory, creates "testfile" holding
// `$size` bytes where byte `v` has the value `v as u8` (0, 1, ..., 255, 0, ...),
// reopens it as `$file`, binds `$size` to `$file_size`, and runs `$code`
// inside `test_executor!`.
macro_rules! read_test {
( $name:ident, $dir:ident, $kind:ident, $file:ident, $file_size:ident: $size:tt, $code:block) => {
#[test]
fn $name() {
for dir in make_test_directories(stringify!($name)) {
let $dir = dir.path.clone();
let $kind = dir.kind;
test_executor!(async move {
let filename = $dir.join("testfile");
let new_file = BufferedFile::create(&filename)
.await
.expect("failed to create file");
if $size > 0 {
let mut buf = Vec::new();
for v in 0..$size {
buf.push(v as u8);
}
new_file.write_at(buf, 0).await.unwrap();
}
new_file.close().await.unwrap();
let $file = BufferedFile::open(&filename).await.unwrap();
let $file_size = $size;
$code
});
}
}
};
}
// Declares a #[test] that, for each test directory, creates "testfile" with
// `BufferedFile::create`, binds it to `$file`, and runs `$code` inside `test_executor!`.
macro_rules! write_test {
( $name:ident, $dir:ident, $kind:ident, $file:ident, $code:block) => {
#[test]
fn $name() {
for dir in make_test_directories(stringify!($name)) {
let $dir = dir.path.clone();
let $kind = dir.kind;
test_executor!(async move {
let filename = $dir.join("testfile");
let $file = BufferedFile::create(&filename).await.unwrap();
$code
});
}
}
};
}
// Asserts that `$buf` matches the `read_test!` byte pattern starting at file
// offset `$start`, i.e. element `idx` equals `($start + idx) as u8`.
macro_rules! check_contents {
( $buf:expr, $start:expr ) => {
for (idx, i) in $buf.iter().enumerate() {
assert_eq!(*i, ($start + (idx as u64)) as u8);
}
};
}
// read_exact on an empty file must fail with UnexpectedEof.
read_test!(read_exact_empty_file, path, _k, file, _file_size: 0, {
let mut reader = StreamReaderBuilder::new(file).build();
let mut buf = [0u8; 128];
match reader.read_exact(&mut buf).await {
Err(x) => match x.kind() {
ErrorKind::UnexpectedEof => {},
_ => panic!("unexpected error"),
}
_ => panic!("unexpected success"),
}
reader.close().await.unwrap();
});
read_test!(read_exact_part_of_file, path, _k, file, _file_size: 4096, {
let mut reader = StreamReaderBuilder::new(file).build();
let mut buf = [0u8; 128];
reader.read_exact(&mut buf).await.unwrap();
check_contents!(buf, 0);
reader.close().await.unwrap();
});
read_test!(seek_start_and_read_exact, path, _k, file, _file_size: 4096, {
let mut reader = StreamReaderBuilder::new(file).build();
let mut buf = [0u8; 128];
reader.seek(SeekFrom::Start(10)).await.unwrap();
reader.read_exact(&mut buf).await.unwrap();
check_contents!(buf, 10);
reader.close().await.unwrap();
});
// End(-96) on a 4096-byte file lands at offset 4000.
read_test!(seek_end_and_read_exact, path, _k, file, _file_size: 4096, {
let mut reader = StreamReaderBuilder::new(file).build();
let mut buf = [0u8; 96];
reader.seek(SeekFrom::End(-96)).await.unwrap();
reader.read_exact(&mut buf).await.unwrap();
check_contents!(buf, 4000);
reader.close().await.unwrap();
});
read_test!(read_to_end_from_start, path, _k, file, file_size: 4096, {
let mut reader = StreamReaderBuilder::new(file).build();
let mut buf = Vec::new();
let x = reader.read_to_end(&mut buf).await.unwrap();
assert_eq!(x, file_size);
check_contents!(buf, 0);
reader.close().await.unwrap();
});
// A [2, 12) range yields exactly the 10 bytes at offsets 2..=11.
read_test!(read_slice, path, _k, file, _file_size: 4096, {
let mut reader = StreamReaderBuilder::new(file)
.with_start_pos(2)
.with_end_pos(12)
.build();
let mut buf = Vec::new();
let x = reader.read_to_end(&mut buf).await.unwrap();
assert_eq!(x, 10);
check_contents!(buf, 2);
reader.close().await.unwrap();
});
read_test!(read_to_end_after_seek, path, _k, file, _file_size: 4096, {
let mut reader = StreamReaderBuilder::new(file).build();
let mut buf = Vec::new();
reader.seek(SeekFrom::Start(4000)).await.unwrap();
let x = reader.read_to_end(&mut buf).await.unwrap();
assert_eq!(x, 96);
check_contents!(buf, 4000);
reader.close().await.unwrap();
});
read_test!(read_until_empty_file, path, _k, file, _file_size: 0, {
let mut reader = StreamReaderBuilder::new(file).build();
let mut buf = Vec::new();
let x = reader.read_until(0xA, &mut buf).await.unwrap();
assert_eq!(x, 0);
reader.close().await.unwrap();
});
// 0xA first appears at offset 10 (so 11 bytes including the delimiter), then
// every 256 bytes after that.
read_test!(read_until, path, _k, file, _file_size: 8192, {
let mut reader = StreamReaderBuilder::new(file).build();
let mut buf = Vec::new();
let x = reader.read_until(0xA, &mut buf).await.unwrap();
assert_eq!(x, 0xB);
for _ in 0..16 {
let x = reader.read_until(0xA, &mut buf).await.unwrap();
assert_eq!(x, 256);
}
reader.close().await.unwrap();
});
// A 32-byte file contains no 255 byte, so read_until runs to EOF.
read_test!(read_until_eof, path, _k, file, file_size: 32, {
let mut reader = StreamReaderBuilder::new(file).build();
let mut buf = Vec::new();
let x = reader.read_until(255, &mut buf).await.unwrap();
assert_eq!(x, file_size);
reader.close().await.unwrap();
});
read_test!(read_line, path, _k, file, _file_size: 8192, {
let mut reader = StreamReaderBuilder::new(file).build();
let mut buf = String::new();
let x = reader.read_line(&mut buf).await.unwrap();
assert_eq!(x, 0xB);
reader.close().await.unwrap();
});
// Round-trip: text written by StreamWriter is read back line by line.
write_test!(lines, path, _k, file, {
let mut writer = StreamWriterBuilder::new(file).build();
writer.write_all(b"123\n123\n123\n").await.unwrap();
writer.close().await.unwrap();
let filename = path.join("testfile");
let file = BufferedFile::open(&filename).await.unwrap();
let reader = StreamReaderBuilder::new(file).build();
let mut lines = reader.lines();
let mut found = 0;
while let Some(line) = lines.next().await {
assert_eq!(line.unwrap(), "123");
found += 1;
}
assert_eq!(found, 3)
});
// 255 occurs once every 256 bytes, so the file splits into 16 pieces of 255
// bytes each (the delimiter is not included).
read_test!(split, path, _k, file, file_size: 4096, {
let reader = StreamReaderBuilder::new(file).build();
let split : Vec<Vec<u8>> = reader.split(255).try_collect().await.unwrap();
assert_eq!(split.len(), file_size / 256);
for line in split.iter() {
assert_eq!(line.len(), 255);
}
});
// Buffered (read_until) and plain (read_exact) reads continue from the same position.
read_test!(mix_and_match_apis, path, _k, file, _file_size: 4096, {
let mut reader = StreamReaderBuilder::new(file).build();
let mut buf = Vec::new();
let x = reader.read_until(0xA, &mut buf).await.unwrap();
assert_eq!(x, 0xB);
let mut buf = [0u8; 10];
reader.read_exact(&mut buf).await.unwrap();
check_contents!(buf, 0xB);
reader.close().await.unwrap();
});
write_test!(write_simple, path, _k, file, {
let mut writer = StreamWriterBuilder::new(file).build();
writer.write_all(&[0, 1, 2, 3, 4]).await.unwrap();
writer.close().await.unwrap();
let filename = path.join("testfile");
let file = BufferedFile::open(&filename).await.unwrap();
assert_eq!(file.file_size().await.unwrap(), 5);
file.close().await.unwrap();
});
// Writing at offset 10 of an empty file leaves the first 10 bytes zero.
write_test!(seek_and_write, path, _k, file, {
let mut writer = StreamWriterBuilder::new(file).build();
writer.seek(SeekFrom::Start(10)).await.unwrap();
writer.write_all(&[0, 1, 2, 3, 4]).await.unwrap();
writer.close().await.unwrap();
let filename = path.join("testfile");
let file = BufferedFile::open(&filename).await.unwrap();
assert_eq!(file.file_size().await.unwrap(), 15);
let mut reader = StreamReaderBuilder::new(file).build();
let mut buf = [0u8; 15];
reader.read_exact(&mut buf).await.unwrap();
check_contents!(&buf[10..], 0);
for i in buf[..10].iter() {
assert_eq!(*i, 0);
}
reader.close().await.unwrap();
});
// Seeking to End(0) between two writes must still produce one contiguous 10-byte file.
write_test!(nop_seek_write, path, _k, file, {
let mut writer = StreamWriterBuilder::new(file).build();
writer.write_all(&[0, 1, 2, 3, 4]).await.unwrap();
writer.seek(SeekFrom::End(0)).await.unwrap();
writer.write_all(&[5, 6, 7, 8, 9]).await.unwrap();
writer.close().await.unwrap();
let filename = path.join("testfile");
let file = BufferedFile::open(&filename).await.unwrap();
assert_eq!(file.file_size().await.unwrap(), 10);
let mut reader = StreamReaderBuilder::new(file).build();
let mut buf = [0u8; 10];
reader.read_exact(&mut buf).await.unwrap();
check_contents!(buf, 0);
reader.close().await.unwrap();
});
// After flush (and before close), another handle must already see the written bytes.
write_test!(write_and_flush, path, _k, file, {
let mut writer = StreamWriterBuilder::new(file).build();
writer.write_all(&[0, 1, 2, 3, 4]).await.unwrap();
writer.flush().await.unwrap();
let filename = path.join("testfile");
let file = BufferedFile::open(&filename).await.unwrap();
let mut reader = StreamReaderBuilder::new(file).build();
let mut buf = Vec::new();
let x = reader.read_to_end(&mut buf).await.unwrap();
assert_eq!(x, 5);
check_contents!(buf, 0);
reader.close().await.unwrap();
writer.close().await.unwrap();
});
}