1use crate::common::SequenceConfig;
18use crate::context::Context;
19use crate::{IoOutput, Output, Writable, WritableSeq};
20use std::convert::Infallible;
21use std::fmt::Debug;
22use std::marker::PhantomData;
23use std::pin::pin;
24use std::task::{Poll, Waker};
25
/// Wraps a function or closure so it can be used as a `SequenceConfig`.
///
/// For `F: Fn(&T) -> W` with `W: Writable<O>`, each datum is passed to the
/// wrapped function and the value it returns is written to the output.
pub struct WritableFromFunction<F>(pub F);
29
30impl<F, T, W, O> SequenceConfig<T, O> for WritableFromFunction<F>
31where
32    O: Output,
33    F: Fn(&T) -> W,
34    W: Writable<O>,
35{
36    async fn write_datum(&self, datum: &T, output: &mut O) -> Result<(), O::Error> {
37        let writable = (&self.0)(datum);
38        writable.write_to(output).await
39    }
40}
41
/// I/O sink that appends everything written to it onto a borrowed [`String`].
pub struct InMemoryIo<'s>(pub &'s mut String);
45
46impl IoOutput for InMemoryIo<'_> {
47    type Error = Infallible;
48
49    async fn write(&mut self, value: &str) -> Result<(), Self::Error> {
50        self.0.push_str(value);
51        Ok(())
52    }
53}
54
/// An output that collects everything written to it in an in-memory string.
///
/// `Ctx` is the context value handed out by the `Output` implementation and
/// `Err` (defaulting to [`Infallible`]) is the error type that implementation
/// reports.
pub struct InMemoryOutput<Ctx, Err = Infallible> {
    /// All text written so far.
    buf: String,
    /// Context value returned by `Output::context` and `Output::split`.
    context: Ctx,
    /// Ties `Err` to the type without storing a value of it.
    error_type: PhantomData<fn(Infallible) -> Err>,
}
69
70impl<Ctx, Err> InMemoryOutput<Ctx, Err> {
71    pub fn new(context: Ctx) -> Self {
72        Self {
73            buf: String::new(),
74            context,
75            error_type: PhantomData,
76        }
77    }
78}
79
80impl<Ctx, Err> Output for InMemoryOutput<Ctx, Err>
81where
82    Ctx: Context,
83    Err: From<Infallible>,
84{
85    type Io<'b>
86        = InMemoryIo<'b>
87    where
88        Self: 'b;
89    type Ctx = Ctx;
90    type Error = Err;
91
92    async fn write(&mut self, value: &str) -> Result<(), Self::Error> {
93        self.buf.push_str(value);
94        Ok(())
95    }
96
97    fn split(&mut self) -> (Self::Io<'_>, &Self::Ctx) {
98        (InMemoryIo(&mut self.buf), &self.context)
99    }
100
101    fn context(&self) -> &Self::Ctx {
102        &self.context
103    }
104}
105
106impl<Ctx, Err> InMemoryOutput<Ctx, Err>
107where
108    Ctx: Context,
109    Err: From<Infallible>,
110{
111    pub fn try_print_output<W>(context: Ctx, writable: &W) -> Result<String, Err>
120    where
121        W: Writable<Self>,
122    {
123        let mut output = Self::new(context);
124        let result = output.print_output_impl(writable);
125        result.map(|()| output.buf)
126    }
127
128    fn print_output_impl<W>(&mut self, writable: &W) -> Result<(), Err>
129    where
130        W: Writable<Self>,
131    {
132        let future = pin!(writable.write_to(self));
133        match future.poll(&mut std::task::Context::from_waker(Waker::noop())) {
134            Poll::Pending => panic!("Expected a complete future"),
135            Poll::Ready(result) => result,
136        }
137    }
138}
139
140impl<Ctx> InMemoryOutput<Ctx>
141where
142    Ctx: Context,
143{
144    pub fn print_output<W>(context: Ctx, writable: &W) -> String
152    where
153        W: Writable<Self>,
154    {
155        Self::try_print_output(context, writable).unwrap_or_else(|e| match e {})
156    }
157}
158
/// Pairs a context with a sequence so the sequence can be iterated as strings.
///
/// The `IntoIterator` implementation yields one `Result<String, Err>` per
/// element when `Seq: WritableSeq<InMemoryOutput<Ctx, Err>>`.
pub struct IntoStringIter<Ctx, Seq, Err = Infallible> {
    /// Context handed to the in-memory output used while iterating.
    context: Ctx,
    /// The sequence whose elements are rendered.
    sequence: Seq,
    /// Ties `Err` to the type without storing a value of it.
    error_type: PhantomData<fn(Infallible) -> Err>,
}
191
192impl<Ctx, Seq, Err> IntoStringIter<Ctx, Seq, Err> {
193    pub fn new(context: Ctx, sequence: Seq) -> Self {
194        Self {
195            context,
196            sequence,
197            error_type: PhantomData,
198        }
199    }
200}
201
202impl<Ctx, Seq, Err> Clone for IntoStringIter<Ctx, Seq, Err>
203where
204    Ctx: Clone,
205    Seq: Clone,
206{
207    fn clone(&self) -> Self {
208        Self {
209            context: self.context.clone(),
210            sequence: self.sequence.clone(),
211            error_type: PhantomData,
212        }
213    }
214}
215
216impl<Ctx, Seq, Err> Debug for IntoStringIter<Ctx, Seq, Err>
217where
218    Ctx: Debug,
219    Seq: Debug,
220{
221    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
222        f.debug_struct("IntoStringIter")
223            .field("context", &self.context)
224            .field("sequence", &self.sequence)
225            .field("error_type", &std::any::type_name::<Err>())
226            .finish()
227    }
228}
229
230impl<Ctx, Seq, Err> IntoIterator for IntoStringIter<Ctx, Seq, Err>
231where
232    Ctx: Context,
233    Seq: WritableSeq<InMemoryOutput<Ctx, Err>>,
234    Err: From<Infallible>,
235{
236    type Item = Result<String, Err>;
237    type IntoIter = ToStringIter<Ctx, Seq, Err>;
238    fn into_iter(self) -> Self::IntoIter {
239        ToStringIter(string_iter::StringIter::new(self.context, self.sequence))
240    }
241}
242
/// Iterator returned by `IntoStringIter::into_iter`; a public wrapper around
/// the private `string_iter::StringIter`.
pub struct ToStringIter<Ctx, Seq, Err = Infallible>(string_iter::StringIter<Ctx, Seq, Err>);
245
246impl<Ctx, Seq, Err> Debug for ToStringIter<Ctx, Seq, Err>
247where
248    Err: Debug,
249{
250    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
251        f.debug_struct("ToStringIter")
252            .field("inner", &self.0)
253            .finish()
254    }
255}
256
257impl<Ctx, Seq, Err> Iterator for ToStringIter<Ctx, Seq, Err>
258where
259    Ctx: Context,
260    Seq: WritableSeq<InMemoryOutput<Ctx, Err>>,
261    Err: From<Infallible>,
262{
263    type Item = Result<String, Err>;
264    fn next(&mut self) -> Option<Self::Item> {
265        self.0.next()
266    }
267}
268
269mod string_iter {
270    use crate::context::Context;
271    use crate::util::InMemoryOutput;
272    use crate::{SequenceAccept, Writable, WritableSeq};
273    use std::cell::UnsafeCell;
274    use std::convert::Infallible;
275    use std::fmt::Debug;
276    use std::future::poll_fn;
277    use std::marker::PhantomData;
278    use std::mem::{ManuallyDrop, MaybeUninit};
279    use std::ops::DerefMut;
280    use std::pin::Pin;
281    use std::ptr::NonNull;
282    use std::task::{Poll, Waker};
283    use std::{mem, ptr};
284
    /// Iterator state that drives a heap-allocated `Progressor` one poll at a time.
    pub struct StringIter<Ctx, Seq, Err> {
        /// Marks the logical use of `Ctx` and `Seq`, which only appear behind `inner`.
        marker: PhantomData<(Ctx, Seq)>,
        /// Progressor allocated by `make_raw_progressor`; released in `Drop`.
        inner: NonNull<Progressor<Ctx, Seq, Err>>,
        /// Sequence-level error that arrived together with an item; it is yielded
        /// by the following call to `next`.
        seq_error_in_pipe: Option<Err>,
        /// Set once iteration has ended; `next` returns `None` from then on.
        finished: bool,
    }
294
    /// Debug output shows the type names of `Ctx` and `Seq`, the pending item
    /// buffer and the iterator's own bookkeeping fields. The future and the
    /// vault are not printed.
    impl<Ctx, Seq, Err> Debug for StringIter<Ctx, Seq, Err>
    where
        Err: Debug,
    {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            let inner = self.inner.as_ptr();
            // SAFETY: `inner` is produced by `make_raw_progressor` with the `buffer`
            // field initialised, and it is only freed in `Drop`, so it is valid
            // while `&self` exists. Only the `buffer` field is borrowed here.
            // NOTE(review): this also relies on nothing polling the future while
            // the borrow is alive — confirm no caller can do that.
            let buffer = unsafe {
                &*(&raw const (*inner).buffer)
            };
            f.debug_struct("StringIter")
                .field(
                    "marker",
                    &(&std::any::type_name::<Ctx>(), &std::any::type_name::<Seq>()),
                )
                .field("inner.buffer", buffer)
                .field("seq_error_in_pipe", &self.seq_error_in_pipe)
                .field("finished", &self.finished)
                .finish()
        }
    }
317
318    impl<Ctx, Seq, Err> Drop for StringIter<Ctx, Seq, Err> {
319        fn drop(&mut self) {
320            Progressor::deallocate(self.inner)
321        }
322    }
323
    /// `RawProgressor` with its future erased to a trait object, so `StringIter`
    /// can point at it without naming the concrete future type.
    type Progressor<Ctx, Seq, Err> =
        RawProgressor<Ctx, Seq, Err, dyn Future<Output = Result<(), Err>>>;
326
    /// Heap-allocated state shared between the iterator and the future it drives.
    ///
    /// `Fut` may be unsized, which is why `future` is the last field.
    struct RawProgressor<Ctx, Seq, Err, Fut: ?Sized> {
        /// Slot the acceptor fills with an item and the iterator empties.
        buffer: ItemBuffer<Err>,
        /// Acceptor and sequence; `make_raw_progressor` passes a reference to this
        /// field to the closure that builds `future`.
        vault: RawProgressorVault<Ctx, Seq, Err>,
        /// The running sequence future. Wrapped in `ManuallyDrop` so `deallocate`
        /// can drop it explicitly before the other fields.
        future: ManuallyDrop<Fut>,
    }
345
    /// The values the sequence future is built from: the acceptor that receives
    /// each element and the sequence being iterated.
    struct RawProgressorVault<Ctx, Seq, Err> {
        /// Receives every element and renders it into the shared item buffer.
        acceptor: SeqAccept<Ctx, Err>,
        /// The sequence whose `for_each` future is polled by the iterator.
        sequence: Seq,
    }
352
    impl<Ctx, Seq, Err> StringIter<Ctx, Seq, Err>
    where
        Ctx: Context,
        Seq: WritableSeq<InMemoryOutput<Ctx, Err>>,
        Err: From<Infallible>,
    {
        /// Creates an iterator over `sequence`, rendering each element with an
        /// in-memory output built from `context`.
        ///
        /// The sequence's `for_each` future is created here but not polled; the
        /// polling happens in `next`.
        pub fn new(context: Ctx, sequence: Seq) -> Self {
            let ptr = Self::make_raw_progressor(context, sequence, |vault| {
                WritableSeq::for_each(&vault.sequence, &mut vault.acceptor)
            });
            Self {
                marker: PhantomData,
                inner: ptr,
                seq_error_in_pipe: None,
                finished: false,
            }
        }

        /// Allocates a `RawProgressor` on the heap and initialises it field by field.
        ///
        /// `make_fut` receives a mutable reference to the already-initialised vault
        /// inside that allocation and returns the future to store next to it, so
        /// the result is self-referential. The returned pointer has the future
        /// type erased and must be released with `Progressor::deallocate`.
        ///
        /// NOTE(review): if `make_fut` panics, the allocation and the fields
        /// written so far appear to be leaked rather than dropped.
        fn make_raw_progressor<'f, MakeFut, Fut>(
            context: Ctx,
            sequence: Seq,
            make_fut: MakeFut,
        ) -> NonNull<Progressor<Ctx, Seq, Err>>
        where
            Fut: Future<Output = Result<(), Err>> + 'f,
            MakeFut: FnOnce(&'f mut RawProgressorVault<Ctx, Seq, Err>) -> Fut,
            Ctx: 'f,
            Seq: 'f,
            Err: 'f,
        {
            let allocated = Box::new(MaybeUninit::<RawProgressor<Ctx, Seq, Err, Fut>>::uninit());
            // SAFETY: the pointer comes from a live `Box`, so it is non-null and
            // properly aligned. Each field is written exactly once through a raw
            // field pointer before anything reads it.
            unsafe {
                let fields_ptr = Box::into_raw(allocated);
                let fields_ptr = (&mut *fields_ptr).as_mut_ptr();

                // The buffer goes first: the acceptor below stores a pointer to it.
                let buffer_ptr = &raw mut (*fields_ptr).buffer;
                ptr::write(buffer_ptr, ItemBuffer::default());

                let vault_ptr = &raw mut (*fields_ptr).vault;
                ptr::write(
                    vault_ptr,
                    RawProgressorVault {
                        acceptor: SeqAccept {
                            output: InMemoryOutput::new(context),
                            buffer: buffer_ptr,
                        },
                        sequence,
                    },
                );

                // The `&'f mut` handed to `make_fut` is created from a raw pointer,
                // so the compiler does not tie `'f` to the allocation. It stays valid
                // because the heap allocation does not move and `deallocate` drops
                // the future before the vault.
                let future_ptr = &raw mut (*fields_ptr).future;
                ptr::write(future_ptr, ManuallyDrop::new(make_fut(&mut *vault_ptr)));

                // Erases the concrete future type; non-null because it came from `Box`.
                NonNull::new_unchecked(fields_ptr as *mut Progressor<Ctx, Seq, Err>)
            }
        }
    }
429
    impl<Ctx, Seq, Err> Progressor<Ctx, Seq, Err> {
        /// Drops the stored future and then frees the allocation behind `ptr`.
        ///
        /// `ptr` is expected to come from `make_raw_progressor` and must not be
        /// used afterwards; the only caller visible in this file is the `Drop`
        /// implementation of `StringIter`.
        fn deallocate(ptr: NonNull<Self>) {
            let ptr = ptr.as_ptr();
            // SAFETY: relies on `ptr` being a fully initialised progressor that
            // originated from `Box::into_raw` and has not been freed yet.
            unsafe {
                {
                    // The future lives in a `ManuallyDrop`, so it is dropped by hand,
                    // and first, because it was built from a reference to `vault`.
                    let future_ptr = &raw mut (*ptr).future;
                    let future_ref = (&mut *future_ptr).deref_mut();
                    ptr::drop_in_place(future_ref);
                }
                // Rebuilding the `Box` drops `buffer` and `vault` and frees the
                // memory; the `ManuallyDrop` field is not dropped a second time.
                let _allocation = Box::from_raw(ptr);

                }
        }
    }
460
    impl<Ctx, Seq, Err> Iterator for StringIter<Ctx, Seq, Err> {
        type Item = Result<String, Err>;

        /// Polls the sequence future once and returns the item it left in the buffer.
        ///
        /// A sequence-level error is yielded as the final `Err` item; if it arrives
        /// together with an item, the item is returned first and the error on the
        /// following call.
        ///
        /// # Panics
        /// Panics if the future is pending without having produced an item.
        fn next(&mut self) -> Option<Self::Item> {
            // Once finished, the future must not be polled again.
            if self.finished {
                return None;
            }
            // Deliver an error that was held back behind the previous item.
            if let Some(error) = mem::take(&mut self.seq_error_in_pipe) {
                self.finished = true;
                return Some(Err(error));
            }
            let fields_ptr = self.inner.as_ptr();
            // SAFETY: `inner` stays valid until `Drop`. The future is only reached
            // through this heap pointer and the allocation does not move, which is
            // what `Pin::new_unchecked` needs. `finished` prevents polling after
            // completion.
            let (poll_outcome, item) = unsafe {
                let future_ptr = &raw mut (*fields_ptr).future;
                let pinned_future = Pin::new_unchecked((&mut *future_ptr).deref_mut());

                // The waker does nothing: progress is made by polling again from
                // the next `next` call, not by a wake-up.
                let poll_outcome =
                    pinned_future.poll(&mut std::task::Context::from_waker(Waker::noop()));

                // Take whatever the acceptor stored during this poll.
                let buffer_ptr = &raw const (*fields_ptr).buffer;
                let item = (&*buffer_ptr).extract();

                (poll_outcome, item)
            };
            match poll_outcome {
                Poll::Pending => {
                    // The acceptor only returns pending while the buffer is full, so
                    // pending without an item means something else awaited.
                    assert!(
                        item.is_some(),
                        "Extraneous async computations (writable should complete regularly)"
                    );
                }
                Poll::Ready(Err(seq_error)) => {
                    if item.is_some() {
                        // Return the item now and keep the error for the next call.
                        self.seq_error_in_pipe = Some(seq_error);
                    } else {
                        self.finished = true;
                        return Some(Err(seq_error));
                    }
                }
                Poll::Ready(Ok(())) => {
                    // The sequence is complete; `item`, if any, is the last element.
                    self.finished = true;
                    }
            };
            item
        }
    }
521
    /// Acceptor that renders each element into `output` and hands the resulting
    /// string (or error) over through the shared item buffer.
    struct SeqAccept<Ctx, Err> {
        /// Scratch output; its text buffer is taken after every element.
        output: InMemoryOutput<Ctx, Err>,
        /// Points at the `buffer` field of the allocation that also holds this
        /// acceptor (set up in `make_raw_progressor`).
        buffer: *const ItemBuffer<Err>,
    }
526
    impl<Ctx, Err> SequenceAccept<InMemoryOutput<Ctx, Err>> for SeqAccept<Ctx, Err>
    where
        Ctx: Context,
        Err: From<Infallible>,
    {
        /// Renders `writable` and stores the outcome in the item buffer.
        ///
        /// Stays pending while the buffer still holds an unconsumed item. A write
        /// error is stored in the buffer as an `Err` item and this method still
        /// returns `Ok(())`, so the sequence continues with the next element.
        async fn accept<W>(&mut self, writable: &W) -> Result<(), Err>
        where
            W: Writable<InMemoryOutput<Ctx, Err>>,
        {
            // The task context is ignored: no waker is registered, so this relies
            // on the driver polling again after it has emptied the buffer.
            poll_fn(|_| {
                // SAFETY: relies on `self.buffer` pointing at the buffer that lives
                // in the same allocation as this acceptor, as `make_raw_progressor`
                // sets it up.
                let buffer = unsafe {
                    &*self.buffer
                };
                if !buffer.has_space() {
                    return Poll::Pending;
                }
                let result = self.output.print_output_impl(writable);
                // Taking the text also resets the output for the next element; on
                // error the partial text is dropped together with the closure below.
                let string = mem::take(&mut self.output.buf);
                buffer.set_new(result.map(|()| string));
                Poll::Ready(Ok(()))
            })
            .await
        }
    }
553
554    struct ItemBuffer<Err> {
555        current: UnsafeCell<Option<Result<String, Err>>>,
556    }
557
558    impl<Err> Default for ItemBuffer<Err> {
559        fn default() -> Self {
560            Self {
561                current: UnsafeCell::new(None),
562            }
563        }
564    }
565
566    impl<Err> Debug for ItemBuffer<Err>
567    where
568        Err: Debug,
569    {
570        fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
571            unsafe {
572                let current = &*self.current.get();
575                f.debug_struct("ItemBuffer")
576                    .field("current", ¤t)
577                    .finish()
578            }
579        }
580    }
581
582    impl<Err> ItemBuffer<Err> {
583        fn has_space(&self) -> bool {
584            unsafe {
585                let ptr = self.current.get();
588                (&*ptr).is_none()
589            }
590        }
591
592        fn set_new(&self, value: Result<String, Err>) {
593            unsafe {
594                let ptr = self.current.get();
597                ptr::write(ptr, Some(value));
599            }
600        }
601
602        fn extract(&self) -> Option<Result<String, Err>> {
603            unsafe {
604                let ptr = self.current.get();
607                mem::replace(&mut *ptr, None)
608            }
609        }
610    }
611}
612
613#[cfg(test)]
614mod tests {
615    use crate::common::{CombinedSeq, NoOpSeq, SingularSeq, Str, StrArrSeq};
616    use crate::context::EmptyContext;
617    use crate::util::IntoStringIter;
618    use crate::{Output, SequenceAccept, Writable, WritableSeq};
619    use std::convert::Infallible;
620
621    #[test]
622    fn sequence_iterator() {
623        let sequence = StrArrSeq(&["One", "Two", "Three"]);
624        let iterator = IntoStringIter::new(EmptyContext, sequence);
625        let iterator = iterator.into_iter();
626        let expected = &["One", "Two", "Three"].map(|s| Ok::<_, Infallible>(String::from(s)));
627        assert_eq!(iterator.collect::<Vec<_>>(), Vec::from(expected));
628    }
629
630    #[test]
631    fn sequence_iterator_empty() {
632        let sequence = NoOpSeq;
633        let iterator: IntoStringIter<_, _> = IntoStringIter::new(EmptyContext, sequence);
634        let iterator = iterator.into_iter();
635        assert!(iterator.collect::<Vec<_>>().is_empty());
636    }
637
    /// Test sequence whose `for_each` always ends with `Err(SampleError)`.
    ///
    /// With `emit_before` set, the error is returned before `seq` is visited;
    /// otherwise `seq` is forwarded to the sink first and the error follows.
    #[derive(Clone)]
    struct SequenceWithError<Seq> {
        emit_before: bool,
        seq: Seq,
    }
643
644    #[derive(Clone, Debug, PartialEq, Eq)]
645    struct SampleError;
646
647    impl From<Infallible> for SampleError {
649        fn from(value: Infallible) -> Self {
650            match value {}
651        }
652    }
653
654    impl<O, Seq> WritableSeq<O> for SequenceWithError<Seq>
655    where
656        O: Output<Error = SampleError>,
657        Seq: WritableSeq<O>,
658    {
659        async fn for_each<S>(&self, sink: &mut S) -> Result<(), O::Error>
660        where
661            S: SequenceAccept<O>,
662        {
663            if !self.emit_before {
664                self.seq.for_each(sink).await?;
665            }
666            Err(SampleError)
667        }
668    }
669
670    #[test]
671    fn sequence_iterator_seq_error() {
672        let sequence = SequenceWithError {
673            emit_before: true,
674            seq: StrArrSeq(&["Will", "Never", "Be", "Seen"]),
675        };
676        let iterator = IntoStringIter::<_, _, SampleError>::new(EmptyContext, sequence);
677        assert_eq!(Some(Err(SampleError)), iterator.clone().into_iter().next());
678        assert!(iterator.into_iter().find(Result::is_ok).is_none());
679    }
680
681    #[test]
682    fn sequence_iterator_seq_error_afterward() {
683        let sequence = SequenceWithError {
684            emit_before: false,
685            seq: StrArrSeq(&["Data", "More"]),
686        };
687        let iterator = IntoStringIter::<_, _, SampleError>::new(EmptyContext, sequence);
688        assert_eq!(
689            vec![
690                Ok(String::from("Data")),
691                Ok(String::from("More")),
692                Err(SampleError)
693            ],
694            iterator.into_iter().collect::<Vec<_>>()
695        );
696    }
697
698    #[test]
699    fn sequence_iterator_seq_error_in_between() {
700        let sequence = CombinedSeq(
701            StrArrSeq(&["One", "Two"]),
702            SequenceWithError {
703                emit_before: true,
704                seq: SingularSeq(Str("Final")),
705            },
706        );
707        let iterator = IntoStringIter::<_, _, SampleError>::new(EmptyContext, sequence);
708        assert_eq!(
709            vec![
710                Ok(String::from("One")),
711                Ok(String::from("Two")),
712                Err(SampleError)
713            ],
714            iterator.into_iter().collect::<Vec<_>>()
715        );
716    }
717
718    #[test]
719    fn sequence_iterator_seq_error_empty() {
720        let sequence = SequenceWithError {
721            emit_before: true,
722            seq: NoOpSeq,
723        };
724        let iterator = IntoStringIter::<_, _, SampleError>::new(EmptyContext, sequence);
725        assert_eq!(
726            vec![Err(SampleError)],
727            iterator.into_iter().collect::<Vec<_>>()
728        );
729    }
730
731    #[derive(Clone, Debug)]
732    struct ProduceError;
733
734    impl<O> Writable<O> for ProduceError
735    where
736        O: Output<Error = SampleError>,
737    {
738        async fn write_to(&self, _: &mut O) -> Result<(), O::Error> {
739            Err(SampleError)
740        }
741    }
742
743    #[test]
744    fn sequence_iterator_write_error() {
745        let sequence = CombinedSeq(SingularSeq(ProduceError), StrArrSeq(&["Is", "Seen"]));
746        let iterator = IntoStringIter::<_, _, SampleError>::new(EmptyContext, sequence);
747        assert_eq!(Some(Err(SampleError)), iterator.clone().into_iter().next());
748        assert_eq!(
749            vec![
750                Err(SampleError),
751                Ok(String::from("Is")),
752                Ok(String::from("Seen")),
753            ],
754            iterator.into_iter().collect::<Vec<_>>()
755        );
756    }
757
758    #[test]
759    fn sequence_iterator_write_error_afterward() {
760        let sequence = CombinedSeq(StrArrSeq(&["Data", "MoreData"]), SingularSeq(ProduceError));
761        let iterator = IntoStringIter::<_, _, SampleError>::new(EmptyContext, sequence);
762        assert_eq!(
763            vec![
764                Ok(String::from("Data")),
765                Ok(String::from("MoreData")),
766                Err(SampleError)
767            ],
768            iterator.into_iter().collect::<Vec<_>>()
769        );
770    }
771
772    #[test]
773    fn sequence_iterator_write_error_in_between() {
774        let sequence = CombinedSeq(
775            StrArrSeq(&["Data", "Adjacent"]),
776            CombinedSeq(SingularSeq(ProduceError), SingularSeq(Str("Final"))),
777        );
778        let iterator = IntoStringIter::<_, _, SampleError>::new(EmptyContext, sequence);
779        assert_eq!(
780            vec![
781                Ok(String::from("Data")),
782                Ok(String::from("Adjacent")),
783                Err(SampleError),
784                Ok(String::from("Final"))
785            ],
786            iterator.into_iter().collect::<Vec<_>>()
787        );
788    }
789}