use crate::common::SequenceConfig;
use crate::context::Context;
use crate::{IoOutput, Output, Writable, WritableSeq};
use std::convert::Infallible;
use std::fmt::Debug;
use std::marker::PhantomData;
use std::pin::pin;
use std::task::{Poll, Waker};
/// Adapts a plain function or closure into a [`SequenceConfig`].
///
/// For every datum, the wrapped function `F` is called with a reference to the datum and the
/// [`Writable`] it returns is written to the output.
pub struct WritableFromFunction<F>(pub F);
impl<F, T, W, O> SequenceConfig<T, O> for WritableFromFunction<F>
where
    O: Output,
    F: Fn(&T) -> W,
    W: Writable<O>,
{
    /// Builds the writable for `datum` with the wrapped function and writes it to `output`,
    /// forwarding whatever error the writable reports.
    async fn write_datum(&self, datum: &T, output: &mut O) -> Result<(), O::Error> {
        let Self(make_writable) = self;
        let produced = make_writable(datum);
        produced.write_to(output).await
    }
}
/// An [`IoOutput`] that appends everything written to it onto a borrowed `String`.
pub struct InMemoryIo<'s>(pub &'s mut String);
impl IoOutput for InMemoryIo<'_> {
    /// Appending to a `String` cannot fail.
    type Error = Infallible;

    /// Appends `value` to the borrowed string; always succeeds.
    async fn write(&mut self, value: &str) -> Result<(), Self::Error> {
        let Self(target) = self;
        target.push_str(value);
        Ok(())
    }
}
/// An [`Output`] that accumulates everything written to it in an owned `String`.
///
/// `Err` (default [`Infallible`]) is only the error type reported through the `Output` trait;
/// appending to the buffer itself never fails.
pub struct InMemoryOutput<Ctx, Err = Infallible> {
    /// Text written so far.
    buf: String,
    /// Context handed out by `Output::context` and `Output::split`.
    context: Ctx,
    /// Marker for the error type. `fn(Infallible) -> Err` is covariant in `Err` and does not
    /// claim ownership of an `Err` value.
    error_type: PhantomData<fn(Infallible) -> Err>,
}

impl<Ctx, Err> InMemoryOutput<Ctx, Err> {
    /// Creates an output with an empty buffer that carries `context`.
    pub fn new(context: Ctx) -> Self {
        let buf = String::new();
        Self {
            context,
            buf,
            error_type: PhantomData,
        }
    }
}
/// Collects all output into the internal buffer. The IO half returned by `split` appends to
/// that same buffer, so both write paths end up in one `String`.
impl<Ctx, Err> Output for InMemoryOutput<Ctx, Err>
where
    Ctx: Context,
    Err: From<Infallible>,
{
    type Io<'b>
        = InMemoryIo<'b>
    where
        Self: 'b;
    type Ctx = Ctx;
    type Error = Err;

    /// Appends `value` to the buffer; never fails.
    async fn write(&mut self, value: &str) -> Result<(), Self::Error> {
        let Self { buf, .. } = self;
        buf.push_str(value);
        Ok(())
    }

    /// Splits into a writer over the buffer plus a shared view of the context.
    fn split(&mut self) -> (Self::Io<'_>, &Self::Ctx) {
        let Self { buf, context, .. } = self;
        (InMemoryIo(buf), &*context)
    }

    fn context(&self) -> &Self::Ctx {
        let Self { context, .. } = self;
        context
    }
}
impl<Ctx, Err> InMemoryOutput<Ctx, Err>
where
    Ctx: Context,
    Err: From<Infallible>,
{
    /// Renders `writable` into a fresh output built from `context` and returns the text.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `writable`; in that case the partial text is discarded.
    ///
    /// # Panics
    ///
    /// Panics with "Expected a complete future" if writing does not complete on its first poll.
    /// Writes to this output never suspend, so that only happens if `writable` itself awaits
    /// something that is not ready.
    pub fn try_print_output<W>(context: Ctx, writable: &W) -> Result<String, Err>
    where
        W: Writable<Self>,
    {
        let mut output = Self::new(context);
        output.print_output_impl(writable)?;
        Ok(output.buf)
    }

    /// Appends `writable` to `self.buf` by polling its `write_to` future exactly once with a
    /// no-op waker, and returns that future's result.
    ///
    /// # Panics
    ///
    /// Panics with "Expected a complete future" if the future returns `Poll::Pending`.
    fn print_output_impl<W>(&mut self, writable: &W) -> Result<(), Err>
    where
        W: Writable<Self>,
    {
        let mut task_cx = std::task::Context::from_waker(Waker::noop());
        let pending_write = pin!(writable.write_to(self));
        let Poll::Ready(result) = pending_write.poll(&mut task_cx) else {
            panic!("Expected a complete future")
        };
        result
    }
}
impl<Ctx> InMemoryOutput<Ctx>
where
    Ctx: Context,
{
    /// Renders `writable` with `context` and returns the text.
    ///
    /// This is the error-free counterpart of `try_print_output` for the default `Infallible`
    /// error type.
    ///
    /// # Panics
    ///
    /// Panics under the same condition as `try_print_output`: `writable` does not complete on
    /// its first poll.
    pub fn print_output<W>(context: Ctx, writable: &W) -> String
    where
        W: Writable<Self>,
    {
        match Self::try_print_output(context, writable) {
            Ok(rendered) => rendered,
            Err(never) => match never {},
        }
    }
}
/// A reusable recipe for rendering every element of `sequence` with `context`.
///
/// Calling `into_iter` yields one `Result<String, Err>` per element. `Err` (default
/// [`Infallible`]) only appears inside a function-pointer `PhantomData`. As a result it adds no
/// `Clone`/`Debug` requirements and does not affect `Send`/`Sync`.
pub struct IntoStringIter<Ctx, Seq, Err = Infallible> {
    context: Ctx,
    sequence: Seq,
    error_type: PhantomData<fn(Infallible) -> Err>,
}

impl<Ctx, Seq, Err> IntoStringIter<Ctx, Seq, Err> {
    /// Pairs `context` with `sequence`; nothing is rendered until iteration starts.
    pub fn new(context: Ctx, sequence: Seq) -> Self {
        let error_type = PhantomData;
        Self {
            sequence,
            context,
            error_type,
        }
    }
}

/// Manual impl so that cloning does not require `Err: Clone`.
impl<Ctx, Seq, Err> Clone for IntoStringIter<Ctx, Seq, Err>
where
    Ctx: Clone,
    Seq: Clone,
{
    fn clone(&self) -> Self {
        Self::new(self.context.clone(), self.sequence.clone())
    }
}

/// Manual impl so that formatting does not require `Err: Debug`; the error type is shown by name.
impl<Ctx, Seq, Err> Debug for IntoStringIter<Ctx, Seq, Err>
where
    Ctx: Debug,
    Seq: Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            context, sequence, ..
        } = self;
        let error_name = std::any::type_name::<Err>();
        f.debug_struct("IntoStringIter")
            .field("context", context)
            .field("sequence", sequence)
            .field("error_type", &error_name)
            .finish()
    }
}
impl<Ctx, Seq, Err> IntoIterator for IntoStringIter<Ctx, Seq, Err>
where
    Ctx: Context,
    Seq: WritableSeq<InMemoryOutput<Ctx, Err>>,
    Err: From<Infallible>,
{
    type Item = Result<String, Err>;
    type IntoIter = ToStringIter<Ctx, Seq, Err>;

    /// Moves the context and sequence into the heap-backed driver of `sequence.for_each`.
    fn into_iter(self) -> Self::IntoIter {
        let Self {
            context, sequence, ..
        } = self;
        let driver = string_iter::StringIter::new(context, sequence);
        ToStringIter(driver)
    }
}
/// Iterator produced by [`IntoStringIter`]: one rendered `String` (or error) per element.
pub struct ToStringIter<Ctx, Seq, Err = Infallible>(string_iter::StringIter<Ctx, Seq, Err>);

impl<Ctx, Seq, Err> Debug for ToStringIter<Ctx, Seq, Err>
where
    Err: Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(inner) = self;
        f.debug_struct("ToStringIter").field("inner", inner).finish()
    }
}
/// Yields, in order, one item per element the sequence hands to its acceptor:
/// `Ok(text)` for a successfully rendered element and `Err(e)` for one whose rendering failed.
/// An error returned by the sequence itself is yielded last and ends the iteration.
impl<Ctx, Seq, Err> Iterator for ToStringIter<Ctx, Seq, Err>
where
    Ctx: Context,
    Seq: WritableSeq<InMemoryOutput<Ctx, Err>>,
    Err: From<Infallible>,
{
    type Item = Result<String, Err>;

    fn next(&mut self) -> Option<Self::Item> {
        self.0.next()
    }
}

/// `StringIter::next` only returns `None` after setting its `finished` flag, which is never
/// cleared. Every call after the first `None` therefore also returns `None`, so `.fuse()`
/// is a no-op.
impl<Ctx, Seq, Err> std::iter::FusedIterator for ToStringIter<Ctx, Seq, Err>
where
    Ctx: Context,
    Seq: WritableSeq<InMemoryOutput<Ctx, Err>>,
    Err: From<Infallible>,
{
}
mod string_iter {
use crate::context::Context;
use crate::util::InMemoryOutput;
use crate::{SequenceAccept, Writable, WritableSeq};
use std::cell::Cell;
use std::convert::Infallible;
use std::fmt::Debug;
use std::future::poll_fn;
use std::mem::{ManuallyDrop, MaybeUninit};
use std::ops::DerefMut;
use std::pin::Pin;
use std::ptr::NonNull;
use std::task::{Poll, Waker};
use std::{mem, ptr};
pub struct StringIter<Ctx, Seq, Err> {
progressor: NonNull<Progressor<Ctx, Seq, Err>>,
seq_error_in_pipe: Option<Err>,
finished: bool,
}
impl<Ctx, Seq, Err> Debug for StringIter<Ctx, Seq, Err>
where
Err: Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let progressor = self.progressor.as_ptr();
let buffer = unsafe {
&*(&raw const (*progressor).buffer)
};
f.debug_struct("StringIter")
.field(
"marker",
&(&std::any::type_name::<Ctx>(), &std::any::type_name::<Seq>()),
)
.field("progressor.buffer", buffer)
.field("seq_error_in_pipe", &self.seq_error_in_pipe)
.field("finished", &self.finished)
.finish()
}
}
impl<Ctx, Seq, Err> Drop for StringIter<Ctx, Seq, Err> {
fn drop(&mut self) {
Progressor::deallocate(self.progressor)
}
}
/// The type-erased form of [`RawProgressor`] that [`StringIter`] points at.
///
/// The concrete future type, and the lifetime for which it borrows the vault, are erased to
/// `dyn Future` with its default `'static` object bound. `make_raw_progressor` performs that
/// erasure with a pointer transmute.
type Progressor<Ctx, Seq, Err> =
    RawProgressor<Ctx, Seq, Err, dyn Future<Output = Result<(), Err>>>;
/// One heap allocation that holds the `for_each` future together with everything it borrows.
///
/// While the future is alive, the fields are only reached through raw place projections
/// (never through a `&mut` to the whole struct), because `future` holds a `&mut` into `vault`.
struct RawProgressor<Ctx, Seq, Err, Fut: ?Sized> {
    /// Single-slot hand-off: filled by `SeqAccept::accept`, drained by `StringIter::next`.
    buffer: ItemBuffer<Err>,
    /// State the future mutably borrows for its whole life; never moved after initialisation.
    vault: RawProgressorVault<Ctx, Seq, Err>,
    /// Wrapped in `ManuallyDrop` so `deallocate` can drop it *before* the `vault` it borrows.
    /// It must be the last field because `Fut` may be unsized (`dyn Future`).
    future: ManuallyDrop<Fut>,
}
/// The part of the heap block borrowed by the `for_each` future. `StringIter::new` borrows
/// `sequence` shared and `acceptor` mutably.
struct RawProgressorVault<Ctx, Seq, Err> {
    /// Sink handed to `for_each`; renders each element and publishes it into `buffer`.
    acceptor: SeqAccept<Ctx, Err>,
    /// The sequence being rendered.
    sequence: Seq,
}
impl<Ctx, Seq, Err> StringIter<Ctx, Seq, Err>
where
    Ctx: Context,
    Seq: WritableSeq<InMemoryOutput<Ctx, Err>>,
    Err: From<Infallible>,
{
    /// Builds an iterator that lazily drives `sequence.for_each`, rendering each element with an
    /// [`InMemoryOutput`] created from `context`.
    ///
    /// The `for_each` future is only constructed here; it is first polled by `next`.
    pub fn new(context: Ctx, sequence: Seq) -> Self {
        let ptr = Self::make_raw_progressor(context, sequence, |vault| {
            WritableSeq::for_each(&vault.sequence, &mut vault.acceptor)
        });
        Self {
            progressor: ptr,
            seq_error_in_pipe: None,
            finished: false,
        }
    }
    /// Allocates the self-referential [`RawProgressor`] on the heap and returns a type-erased
    /// pointer to it.
    ///
    /// The fields are initialised in place, in this order:
    /// 1. the empty item buffer;
    /// 2. the vault, whose acceptor receives a raw pointer to that buffer;
    /// 3. the future produced by `make_fut`, which mutably borrows the vault for `'f`.
    ///
    /// The vault never moves after allocation, so that borrow stays valid for as long as the
    /// allocation lives.
    ///
    /// The returned pointer must be passed to `Progressor::deallocate` exactly once, since that
    /// function reconstitutes and frees the `Box`. If `make_fut` panics, the allocation (and the
    /// context and sequence moved into it) is leaked rather than freed.
    fn make_raw_progressor<'f, MakeFut, Fut>(
        context: Ctx,
        sequence: Seq,
        make_fut: MakeFut,
    ) -> NonNull<Progressor<Ctx, Seq, Err>>
    where
        Fut: Future<Output = Result<(), Err>> + 'f,
        MakeFut: FnOnce(&'f mut RawProgressorVault<Ctx, Seq, Err>) -> Fut,
        Ctx: 'f,
        Seq: 'f,
        Err: 'f,
    {
        let allocated = Box::new(MaybeUninit::<RawProgressor<Ctx, Seq, Err, Fut>>::uninit());
        // SAFETY: every field is written exactly once via `ptr::write` through raw place
        // projections. The only references created point at the `MaybeUninit` wrapper or at an
        // already-initialised field (`vault`). The allocation is leaked out of its `Box`, so
        // nothing moves or frees it until `deallocate`.
        unsafe {
            // From here on the allocation is owned by the raw pointer we return.
            let fields_ptr = Box::into_raw(allocated);
            let fields_ptr = (&mut *fields_ptr).as_mut_ptr();
            let buffer_ptr = &raw mut (*fields_ptr).buffer;
            ptr::write(buffer_ptr, ItemBuffer::default());
            let vault_ptr = &raw mut (*fields_ptr).vault;
            ptr::write(
                vault_ptr,
                RawProgressorVault {
                    acceptor: SeqAccept {
                        output: InMemoryOutput::new(context),
                        // Sibling field in the same allocation; `accept` publishes items here.
                        buffer: buffer_ptr,
                    },
                    sequence,
                },
            );
            let future_ptr = &raw mut (*fields_ptr).future;
            // The vault is initialised and pinned in place by the heap allocation, so the
            // `&'f mut` handed to `make_fut` stays valid until the future is dropped.
            ptr::write(future_ptr, ManuallyDrop::new(make_fut(&mut *vault_ptr)));
            // The argument is unsize-coerced to `dyn Future + 'f`, then `'f` is erased to
            // `'static` so the pointer matches the `Progressor` alias. This relies on
            // `StringIter` carrying `Ctx`, `Seq` and `Err` in its own type, so it cannot outlive
            // the lifetimes the future was created with.
            let fields_ptr = mem::transmute::<
                *mut RawProgressor<_, _, _, dyn Future<Output = Result<(), Err>> + 'f>,
                *mut RawProgressor<_, _, _, dyn Future<Output = Result<(), Err>> + 'static>,
            >(fields_ptr);
            // Derived from `Box::into_raw`, which never returns null.
            NonNull::<Progressor<Ctx, Seq, Err>>::new_unchecked(fields_ptr)
        }
    }
}
impl<Ctx, Seq, Err> Progressor<Ctx, Seq, Err> {
    /// Tears down a block produced by `StringIter::make_raw_progressor`.
    ///
    /// The future is dropped first, because it borrows `vault`. Then the `Box` is reconstituted,
    /// which drops `buffer` and `vault` and frees the memory. The free is performed by a drop
    /// guard, so the allocation is released even if the future's destructor panics. Previously
    /// such a panic skipped `Box::from_raw` and leaked the block together with the context and
    /// sequence it owns.
    ///
    /// Contract (not enforced by the type system): `ptr` must come from `make_raw_progressor`
    /// and must be passed here exactly once; `StringIter::drop` is the only caller.
    fn deallocate(ptr: NonNull<Self>) {
        /// Frees the boxed allocation when dropped, including during unwinding.
        struct FreeOnDrop<T: ?Sized>(*mut T);

        impl<T: ?Sized> Drop for FreeOnDrop<T> {
            fn drop(&mut self) {
                // SAFETY: the pointer came from `Box::into_raw` (via `make_raw_progressor`), is
                // freed only here, and the borrowing future has already been dropped (or its
                // drop has unwound) by the time this guard runs.
                drop(unsafe { Box::from_raw(self.0) });
            }
        }

        let raw = ptr.as_ptr();
        // Declared before the future is dropped so it runs afterwards, even on unwind.
        let _free = FreeOnDrop(raw);
        // SAFETY: `raw` points at a fully initialised block whose `future` has not been dropped
        // yet. The future is dropped exactly once here, before `vault` (which it borrows) is
        // dropped by the guard above.
        unsafe { ManuallyDrop::drop(&mut (*raw).future) };
    }
}
/// Drives the erased `for_each` future one element at a time.
///
/// Each call polls the future once with a no-op waker. `SeqAccept::accept` renders an element
/// only while the shared [`ItemBuffer`] is empty, and otherwise returns `Pending`. The buffer is
/// drained right after each poll, so a single poll yields at most one item.
///
/// Errors from rendering an individual element arrive as `Some(Err(..))` items and iteration
/// continues. An error returned by the sequence itself is reported after any item produced in
/// the same poll, and ends the iteration.
///
/// # Panics
///
/// Panics if the future suspends without having produced an item, i.e. the sequence awaited
/// something other than the acceptor.
impl<Ctx, Seq, Err> Iterator for StringIter<Ctx, Seq, Err> {
    type Item = Result<String, Err>;
    fn next(&mut self) -> Option<Self::Item> {
        // The future has completed; it must never be polled again.
        if self.finished {
            return None;
        }
        // A sequence error that arrived together with the previous item is delivered now,
        // without polling the (already completed) future.
        if let Some(error) = mem::take(&mut self.seq_error_in_pipe) {
            self.finished = true;
            return Some(Err(error));
        }
        let fields_ptr = self.progressor.as_ptr();
        // SAFETY: `progressor` points at the live block created in `new` (freed only in `Drop`).
        // The future lives in that heap block and is never moved, so pinning it in place is sound.
        // The guards above ensure it is not polled after completion. The buffer is accessed only
        // through shared (`Cell`-based) references and is disjoint from the vault the future
        // borrows mutably.
        let (poll_outcome, item) = unsafe {
            let future_ptr = &raw mut (*fields_ptr).future;
            let pinned_future = Pin::new_unchecked((&mut *future_ptr).deref_mut());
            let poll_outcome =
                pinned_future.poll(&mut std::task::Context::from_waker(Waker::noop()));
            let buffer_ptr = &raw const (*fields_ptr).buffer;
            // Drain whatever the acceptor published during this poll, leaving the slot empty.
            let item = (&*buffer_ptr).extract();
            (poll_outcome, item)
        };
        match poll_outcome {
            // Suspended: the only expected reason is waiting for the just-filled buffer to drain.
            Poll::Pending => {
                assert!(
                    item.is_some(),
                    "Extraneous async computations (writable should complete regularly)"
                );
            }
            // The sequence failed. Hand out this poll's item first (if any) and park the error
            // for the following call; otherwise report the error now and stop.
            Poll::Ready(Err(seq_error)) => {
                if item.is_some() {
                    self.seq_error_in_pipe = Some(seq_error);
                } else {
                    self.finished = true;
                    return Some(Err(seq_error));
                }
            }
            // The sequence completed; `item` is the element produced during this final poll, if any.
            Poll::Ready(Ok(())) => {
                self.finished = true;
            }
        };
        item
    }
}
/// The sink `StringIter` hands to `WritableSeq::for_each`.
///
/// Each accepted writable is rendered into `output`, and the text (or error) is published into
/// the shared [`ItemBuffer`]. While that buffer is still full, acceptance suspends, which is how
/// the iterator pauses the sequence between elements.
struct SeqAccept<Ctx, Err> {
    /// Scratch output; its `buf` is drained after every element.
    output: InMemoryOutput<Ctx, Err>,
    /// Points at the `buffer` field of the same heap block that owns this acceptor.
    buffer: *const ItemBuffer<Err>,
}

impl<Ctx, Err> SequenceAccept<InMemoryOutput<Ctx, Err>> for SeqAccept<Ctx, Err>
where
    Ctx: Context,
    Err: From<Infallible>,
{
    /// Waits until the hand-off buffer is empty, then renders `writable` and stores the result.
    ///
    /// Always resolves to `Ok(())`: rendering errors are delivered through the buffer as items,
    /// not returned to the sequence.
    async fn accept<W>(&mut self, writable: &W) -> Result<(), Err>
    where
        W: Writable<InMemoryOutput<Ctx, Err>>,
    {
        let hand_off = poll_fn(|_| {
            // SAFETY: `buffer` targets a sibling field of the heap block that owns `self`. That
            // block outlives the future this call runs in, and the buffer is only ever accessed
            // through shared references.
            let slot = unsafe { &*self.buffer };
            if slot.has_space() {
                let rendered = self.output.print_output_impl(writable);
                let text = mem::take(&mut self.output.buf);
                slot.set_new(rendered.map(|()| text));
                Poll::Ready(Ok(()))
            } else {
                Poll::Pending
            }
        });
        hand_off.await
    }
}
/// Single-slot, `Cell`-based mailbox that passes one rendered item from the acceptor
/// (producer) to the iterator (consumer).
///
/// All access goes through `&self`, so producer and consumer can share it through plain shared
/// references (or raw `*const`) without any `&mut` aliasing.
struct ItemBuffer<Err>(Cell<Option<Result<String, Err>>>);

impl<Err> Default for ItemBuffer<Err> {
    /// An empty slot.
    fn default() -> Self {
        ItemBuffer(Cell::default())
    }
}

impl<Err> ItemBuffer<Err> {
    /// Runs `op` on a view of the current contents and returns its result.
    ///
    /// `Cell` cannot lend out a reference to a non-`Copy` value, so the contents are moved out,
    /// inspected, and put back. (The previous version of this line, `op(¤t)`, was a
    /// mis-encoded `op(&current)` and did not compile.)
    fn inspect<F, R>(&self, op: F) -> R
    where
        F: FnOnce(&Option<Result<String, Err>>) -> R,
    {
        let borrowed = self.0.replace(None);
        let outcome = op(&borrowed);
        self.0.set(borrowed);
        outcome
    }

    /// `true` when no item is waiting to be extracted.
    fn has_space(&self) -> bool {
        self.inspect(|slot| slot.is_none())
    }

    /// Stores `value`, replacing (and dropping) any previous content.
    fn set_new(&self, value: Result<String, Err>) {
        self.0.set(Some(value));
    }

    /// Removes and returns the stored item, leaving the slot empty.
    fn extract(&self) -> Option<Result<String, Err>> {
        self.0.replace(None)
    }
}

impl<Err> Debug for ItemBuffer<Err>
where
    Err: Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        self.inspect(|current| {
            f.debug_struct("ItemBuffer")
                .field("current", current)
                .finish()
        })
    }
}
}
#[cfg(test)]
mod tests {
    // Behavioural tests for `IntoStringIter` / `ToStringIter` and `InMemoryOutput`.
    use crate::common::{CombinedSeq, NoOpSeq, SingularSeq, Str, StrArrSeq};
    use crate::context::EmptyContext;
    use crate::util::{InMemoryOutput, IntoStringIter};
    use crate::{Output, SequenceAccept, Writable, WritableSeq};
    use std::convert::Infallible;

    #[test]
    fn sequence_iterator() {
        let sequence = StrArrSeq(&["One", "Two", "Three"]);
        let iterator = IntoStringIter::new(EmptyContext, sequence);
        let iterator = iterator.into_iter();
        let expected = &["One", "Two", "Three"].map(|s| Ok::<_, Infallible>(String::from(s)));
        assert_eq!(iterator.collect::<Vec<_>>(), Vec::from(expected));
    }

    #[test]
    fn sequence_iterator_empty() {
        let sequence = NoOpSeq;
        let iterator: IntoStringIter<_, _> = IntoStringIter::new(EmptyContext, sequence);
        let iterator = iterator.into_iter();
        assert!(iterator.collect::<Vec<_>>().is_empty());
    }

    /// Keeps returning `None` once exhausted (the iterator is fused).
    #[test]
    fn sequence_iterator_stays_exhausted() {
        let iterator: IntoStringIter<_, _> =
            IntoStringIter::new(EmptyContext, StrArrSeq(&["Only"]));
        let mut iterator = iterator.into_iter();
        assert_eq!(iterator.next(), Some(Ok(String::from("Only"))));
        assert_eq!(iterator.next(), None);
        assert_eq!(iterator.next(), None);
    }

    /// Dropping a partially consumed iterator frees a future suspended inside the sequence.
    #[test]
    fn sequence_iterator_dropped_midway() {
        let iterator: IntoStringIter<_, _> =
            IntoStringIter::new(EmptyContext, StrArrSeq(&["One", "Two", "Three"]));
        let mut iterator = iterator.into_iter();
        assert_eq!(iterator.next(), Some(Ok(String::from("One"))));
        drop(iterator);
    }

    #[test]
    fn print_output_renders_writable() {
        assert_eq!(
            <InMemoryOutput<EmptyContext>>::print_output(EmptyContext, &Str("hello")),
            "hello"
        );
    }

    #[derive(Clone)]
    struct SequenceWithError<Seq> {
        emit_before: bool,
        seq: Seq,
    }

    #[derive(Clone, Debug, PartialEq, Eq)]
    struct SampleError;

    impl From<Infallible> for SampleError {
        fn from(value: Infallible) -> Self {
            match value {}
        }
    }

    impl<O, Seq> WritableSeq<O> for SequenceWithError<Seq>
    where
        O: Output<Error = SampleError>,
        Seq: WritableSeq<O>,
    {
        async fn for_each<S>(&self, sink: &mut S) -> Result<(), O::Error>
        where
            S: SequenceAccept<O>,
        {
            if !self.emit_before {
                self.seq.for_each(sink).await?;
            }
            Err(SampleError)
        }
    }

    #[test]
    fn sequence_iterator_seq_error() {
        let sequence = SequenceWithError {
            emit_before: true,
            seq: StrArrSeq(&["Will", "Never", "Be", "Seen"]),
        };
        let iterator = IntoStringIter::<_, _, SampleError>::new(EmptyContext, sequence);
        assert_eq!(Some(Err(SampleError)), iterator.clone().into_iter().next());
        assert!(iterator.into_iter().find(Result::is_ok).is_none());
    }

    #[test]
    fn sequence_iterator_seq_error_afterward() {
        let sequence = SequenceWithError {
            emit_before: false,
            seq: StrArrSeq(&["Data", "More"]),
        };
        let iterator = IntoStringIter::<_, _, SampleError>::new(EmptyContext, sequence);
        assert_eq!(
            vec![
                Ok(String::from("Data")),
                Ok(String::from("More")),
                Err(SampleError)
            ],
            iterator.into_iter().collect::<Vec<_>>()
        );
    }

    /// After the piped sequence error has been reported, iteration stays finished.
    #[test]
    fn sequence_iterator_stays_exhausted_after_seq_error() {
        let sequence = SequenceWithError {
            emit_before: false,
            seq: StrArrSeq(&["Data"]),
        };
        let mut iterator =
            IntoStringIter::<_, _, SampleError>::new(EmptyContext, sequence).into_iter();
        assert_eq!(iterator.next(), Some(Ok(String::from("Data"))));
        assert_eq!(iterator.next(), Some(Err(SampleError)));
        assert_eq!(iterator.next(), None);
        assert_eq!(iterator.next(), None);
    }

    #[test]
    fn sequence_iterator_seq_error_in_between() {
        let sequence = CombinedSeq(
            StrArrSeq(&["One", "Two"]),
            SequenceWithError {
                emit_before: true,
                seq: SingularSeq(Str("Final")),
            },
        );
        let iterator = IntoStringIter::<_, _, SampleError>::new(EmptyContext, sequence);
        assert_eq!(
            vec![
                Ok(String::from("One")),
                Ok(String::from("Two")),
                Err(SampleError)
            ],
            iterator.into_iter().collect::<Vec<_>>()
        );
    }

    #[test]
    fn sequence_iterator_seq_error_empty() {
        let sequence = SequenceWithError {
            emit_before: true,
            seq: NoOpSeq,
        };
        let iterator = IntoStringIter::<_, _, SampleError>::new(EmptyContext, sequence);
        assert_eq!(
            vec![Err(SampleError)],
            iterator.into_iter().collect::<Vec<_>>()
        );
    }

    #[derive(Clone, Debug)]
    struct ProduceError;

    impl<O> Writable<O> for ProduceError
    where
        O: Output<Error = SampleError>,
    {
        async fn write_to(&self, _: &mut O) -> Result<(), O::Error> {
            Err(SampleError)
        }
    }

    #[test]
    fn sequence_iterator_write_error() {
        let sequence = CombinedSeq(SingularSeq(ProduceError), StrArrSeq(&["Is", "Seen"]));
        let iterator = IntoStringIter::<_, _, SampleError>::new(EmptyContext, sequence);
        assert_eq!(Some(Err(SampleError)), iterator.clone().into_iter().next());
        assert_eq!(
            vec![
                Err(SampleError),
                Ok(String::from("Is")),
                Ok(String::from("Seen")),
            ],
            iterator.into_iter().collect::<Vec<_>>()
        );
    }

    #[test]
    fn sequence_iterator_write_error_afterward() {
        let sequence = CombinedSeq(StrArrSeq(&["Data", "MoreData"]), SingularSeq(ProduceError));
        let iterator = IntoStringIter::<_, _, SampleError>::new(EmptyContext, sequence);
        assert_eq!(
            vec![
                Ok(String::from("Data")),
                Ok(String::from("MoreData")),
                Err(SampleError)
            ],
            iterator.into_iter().collect::<Vec<_>>()
        );
    }

    #[test]
    fn sequence_iterator_write_error_in_between() {
        let sequence = CombinedSeq(
            StrArrSeq(&["Data", "Adjacent"]),
            CombinedSeq(SingularSeq(ProduceError), SingularSeq(Str("Final"))),
        );
        let iterator = IntoStringIter::<_, _, SampleError>::new(EmptyContext, sequence);
        assert_eq!(
            vec![
                Ok(String::from("Data")),
                Ok(String::from("Adjacent")),
                Err(SampleError),
                Ok(String::from("Final"))
            ],
            iterator.into_iter().collect::<Vec<_>>()
        );
    }
}