1use crate::common::SequenceConfig;
18use crate::context::Context;
19use crate::{IoOutput, Output, Writable, WritableSeq};
20use std::convert::Infallible;
21use std::fmt::Debug;
22use std::marker::PhantomData;
23use std::pin::pin;
24use std::task::{Poll, Waker};
25
26pub struct WritableFromFunction<F>(pub F);
29
30impl<F, T, W, O> SequenceConfig<T, O> for WritableFromFunction<F>
31where
32 O: Output,
33 F: Fn(&T) -> W,
34 W: Writable<O>,
35{
36 async fn write_datum(&self, datum: &T, output: &mut O) -> Result<(), O::Error> {
37 let writable = (&self.0)(datum);
38 writable.write_to(output).await
39 }
40}
41
42pub struct InMemoryIo<'s>(pub &'s mut String);
45
46impl IoOutput for InMemoryIo<'_> {
47 type Error = Infallible;
48
49 async fn write(&mut self, value: &str) -> Result<(), Self::Error> {
50 self.0.push_str(value);
51 Ok(())
52 }
53}
54
/// An [`Output`] that accumulates everything written to it in a `String`.
///
/// `Ctx` is the context exposed to writables. `Err` is the error type the
/// output reports. It defaults to [`Infallible`], because writing into memory
/// never fails by itself.
pub struct InMemoryOutput<Ctx, Err = Infallible> {
    buf: String,
    context: Ctx,
    // Records `Err` at the type level without storing one. Because the marker
    // is a `fn` pointer type, the struct neither owns nor drops an `Err`, and
    // `Send`/`Sync` are unaffected by it.
    error_type: PhantomData<fn(Infallible) -> Err>,
}

impl<Ctx, Err> InMemoryOutput<Ctx, Err> {
    /// Wraps `context` in an output whose buffer starts out empty.
    pub fn new(context: Ctx) -> Self {
        let buf = String::new();
        Self {
            context,
            buf,
            error_type: PhantomData,
        }
    }
}
79
80impl<Ctx, Err> Output for InMemoryOutput<Ctx, Err>
81where
82 Ctx: Context,
83 Err: From<Infallible>,
84{
85 type Io<'b>
86 = InMemoryIo<'b>
87 where
88 Self: 'b;
89 type Ctx = Ctx;
90 type Error = Err;
91
92 async fn write(&mut self, value: &str) -> Result<(), Self::Error> {
93 self.buf.push_str(value);
94 Ok(())
95 }
96
97 fn split(&mut self) -> (Self::Io<'_>, &Self::Ctx) {
98 (InMemoryIo(&mut self.buf), &self.context)
99 }
100
101 fn context(&self) -> &Self::Ctx {
102 &self.context
103 }
104}
105
106impl<Ctx, Err> InMemoryOutput<Ctx, Err>
107where
108 Ctx: Context,
109 Err: From<Infallible>,
110{
111 pub fn try_print_output<W>(context: Ctx, writable: &W) -> Result<String, Err>
120 where
121 W: Writable<Self>,
122 {
123 let mut output = Self::new(context);
124 let result = output.print_output_impl(writable);
125 result.map(|()| output.buf)
126 }
127
128 fn print_output_impl<W>(&mut self, writable: &W) -> Result<(), Err>
129 where
130 W: Writable<Self>,
131 {
132 let future = pin!(writable.write_to(self));
133 match future.poll(&mut std::task::Context::from_waker(Waker::noop())) {
134 Poll::Pending => panic!("Expected a complete future"),
135 Poll::Ready(result) => result,
136 }
137 }
138}
139
140impl<Ctx> InMemoryOutput<Ctx>
141where
142 Ctx: Context,
143{
144 pub fn print_output<W>(context: Ctx, writable: &W) -> String
152 where
153 W: Writable<Self>,
154 {
155 Self::try_print_output(context, writable).unwrap_or_else(|e| match e {})
156 }
157}
158
/// A reusable description of a [`WritableSeq`] to be rendered element by
/// element into `String`s.
///
/// It only stores the context and the sequence. Turning it into an iterator
/// yields one `Result<String, Err>` per rendered element.
pub struct IntoStringIter<Ctx, Seq, Err = Infallible> {
    context: Ctx,
    sequence: Seq,
    // Records `Err` at the type level without storing a value of it.
    error_type: PhantomData<fn(Infallible) -> Err>,
}

impl<Ctx, Seq, Err> IntoStringIter<Ctx, Seq, Err> {
    /// Bundles `context` and `sequence`; nothing is rendered here.
    pub fn new(context: Ctx, sequence: Seq) -> Self {
        IntoStringIter {
            sequence,
            context,
            error_type: PhantomData,
        }
    }
}

// Written by hand so that `Err` itself does not have to be `Clone`.
impl<Ctx, Seq, Err> Clone for IntoStringIter<Ctx, Seq, Err>
where
    Ctx: Clone,
    Seq: Clone,
{
    fn clone(&self) -> Self {
        Self::new(self.context.clone(), self.sequence.clone())
    }
}

// Written by hand so that `Err` only needs to be named, not `Debug`.
impl<Ctx, Seq, Err> Debug for IntoStringIter<Ctx, Seq, Err>
where
    Ctx: Debug,
    Seq: Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let error_type = std::any::type_name::<Err>();
        let mut builder = f.debug_struct("IntoStringIter");
        builder.field("context", &self.context);
        builder.field("sequence", &self.sequence);
        builder.field("error_type", &error_type);
        builder.finish()
    }
}
229
230impl<Ctx, Seq, Err> IntoIterator for IntoStringIter<Ctx, Seq, Err>
231where
232 Ctx: Context,
233 Seq: WritableSeq<InMemoryOutput<Ctx, Err>>,
234 Err: From<Infallible>,
235{
236 type Item = Result<String, Err>;
237 type IntoIter = ToStringIter<Ctx, Seq, Err>;
238 fn into_iter(self) -> Self::IntoIter {
239 ToStringIter(string_iter::StringIter::new(self.context, self.sequence))
240 }
241}
242
243pub struct ToStringIter<Ctx, Seq, Err = Infallible>(string_iter::StringIter<Ctx, Seq, Err>);
245
246impl<Ctx, Seq, Err> Debug for ToStringIter<Ctx, Seq, Err>
247where
248 Err: Debug,
249{
250 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
251 f.debug_struct("ToStringIter")
252 .field("inner", &self.0)
253 .finish()
254 }
255}
256
257impl<Ctx, Seq, Err> Iterator for ToStringIter<Ctx, Seq, Err>
258where
259 Ctx: Context,
260 Seq: WritableSeq<InMemoryOutput<Ctx, Err>>,
261 Err: From<Infallible>,
262{
263 type Item = Result<String, Err>;
264 fn next(&mut self) -> Option<Self::Item> {
265 self.0.next()
266 }
267}
268
269mod string_iter {
270 use crate::context::Context;
271 use crate::util::InMemoryOutput;
272 use crate::{SequenceAccept, Writable, WritableSeq};
273 use std::cell::UnsafeCell;
274 use std::convert::Infallible;
275 use std::fmt::Debug;
276 use std::future::poll_fn;
277 use std::marker::PhantomData;
278 use std::mem::{ManuallyDrop, MaybeUninit};
279 use std::ops::DerefMut;
280 use std::pin::Pin;
281 use std::ptr::NonNull;
282 use std::task::{Poll, Waker};
283 use std::{mem, ptr};
284
285 pub struct StringIter<Ctx, Seq, Err> {
286 marker: PhantomData<(Ctx, Seq)>,
288 inner: NonNull<Progressor<Ctx, Seq, Err>>,
289 seq_error_in_pipe: Option<Err>,
292 finished: bool,
293 }
294
295 impl<Ctx, Seq, Err> Debug for StringIter<Ctx, Seq, Err>
296 where
297 Err: Debug,
298 {
299 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
300 let inner = self.inner.as_ptr();
301 let buffer = unsafe {
302 &*(&raw const (*inner).buffer)
305 };
306 f.debug_struct("StringIter")
307 .field(
308 "marker",
309 &(&std::any::type_name::<Ctx>(), &std::any::type_name::<Seq>()),
310 )
311 .field("inner.buffer", buffer)
312 .field("seq_error_in_pipe", &self.seq_error_in_pipe)
313 .field("finished", &self.finished)
314 .finish()
315 }
316 }
317
318 impl<Ctx, Seq, Err> Drop for StringIter<Ctx, Seq, Err> {
319 fn drop(&mut self) {
320 Progressor::deallocate(self.inner)
321 }
322 }
323
    /// Type-erased view of the heap allocation behind a `StringIter`.
    ///
    /// `make_raw_progressor` allocates a `RawProgressor` with a concrete future
    /// type and casts the pointer to this `dyn Future` form, so `StringIter`
    /// never has to name that future type.
    type Progressor<Ctx, Seq, Err> =
        RawProgressor<Ctx, Seq, Err, dyn Future<Output = Result<(), Err>>>;

    /// One heap allocation holding the hand-off buffer, the state driven by the
    /// future, and the future itself.
    ///
    /// The future borrows `vault` from this same allocation. Through
    /// `SeqAccept::buffer` it also points at `buffer`. The allocation therefore
    /// must not move while the future is alive.
    struct RawProgressor<Ctx, Seq, Err, Fut: ?Sized> {
        // Single-item slot: the future deposits a rendered item and
        // `StringIter::next` takes it out.
        buffer: ItemBuffer<Err>,
        // State mutably borrowed by `future` for its whole lifetime.
        vault: RawProgressorVault<Ctx, Seq, Err>,
        // Must remain the last field: it is the possibly-unsized tail that
        // allows the cast to the `dyn Future` form. It is wrapped in
        // `ManuallyDrop` because it has to be dropped explicitly before
        // `vault`, which it borrows (see `deallocate`).
        future: ManuallyDrop<Fut>,
    }

    /// The state the future borrows: the acceptor that renders each element,
    /// and the sequence being iterated.
    struct RawProgressorVault<Ctx, Seq, Err> {
        acceptor: SeqAccept<Ctx, Err>,
        sequence: Seq,
    }
352
    impl<Ctx, Seq, Err> StringIter<Ctx, Seq, Err>
    where
        Ctx: Context,
        Seq: WritableSeq<InMemoryOutput<Ctx, Err>>,
        Err: From<Infallible>,
    {
        /// Creates an iterator over the rendered elements of `sequence`.
        ///
        /// The sequence's `for_each` future is created here but not polled.
        /// Nothing is rendered until the first call to `next`.
        pub fn new(context: Ctx, sequence: Seq) -> Self {
            // The future feeds every element of the sequence into the acceptor
            // stored next to it in the same allocation.
            let ptr = Self::make_raw_progressor(context, sequence, |vault| {
                WritableSeq::for_each(&vault.sequence, &mut vault.acceptor)
            });
            Self {
                marker: PhantomData,
                inner: ptr,
                seq_error_in_pipe: None,
                finished: false,
            }
        }

        /// Allocates a `RawProgressor` on the heap, initialises it in place and
        /// returns a type-erased pointer to it.
        ///
        /// `make_fut` receives a reference to the vault *inside* the allocation,
        /// so the returned future may borrow it. For that reason the fields are
        /// written in place through raw pointers, instead of building the struct
        /// by value and then boxing it. The returned pointer must eventually be
        /// released with `Progressor::deallocate`.
        ///
        /// NOTE(review): if `make_fut` panics, the allocation and the fields
        /// already written are leaked (never dropped or freed).
        fn make_raw_progressor<'f, MakeFut, Fut>(
            context: Ctx,
            sequence: Seq,
            make_fut: MakeFut,
        ) -> NonNull<Progressor<Ctx, Seq, Err>>
        where
            Fut: Future<Output = Result<(), Err>> + 'f,
            MakeFut: FnOnce(&'f mut RawProgressorVault<Ctx, Seq, Err>) -> Fut,
            Ctx: 'f,
            Seq: 'f,
            Err: 'f,
        {
            // Uninitialised heap storage of the concrete (sized) type, so the
            // address is fixed before anything that will be borrowed is written.
            let allocated = Box::new(MaybeUninit::<RawProgressor<Ctx, Seq, Err, Fut>>::uninit());
            unsafe {
                // From here on the memory is owned through raw pointers. It is
                // reclaimed only by `Progressor::deallocate`.
                let fields_ptr = Box::into_raw(allocated);
                let fields_ptr = (&mut *fields_ptr).as_mut_ptr();

                // Fields are initialised one at a time through raw field
                // projections, so no reference to the partially initialised
                // `RawProgressor` is ever created.
                let buffer_ptr = &raw mut (*fields_ptr).buffer;
                ptr::write(buffer_ptr, ItemBuffer::default());

                // The acceptor keeps a raw pointer to `buffer` so it can hand
                // each rendered item over to `StringIter::next`.
                let vault_ptr = &raw mut (*fields_ptr).vault;
                ptr::write(
                    vault_ptr,
                    RawProgressorVault {
                        acceptor: SeqAccept {
                            output: InMemoryOutput::new(context),
                            buffer: buffer_ptr,
                        },
                        sequence,
                    },
                );

                // `vault` is now initialised and never moves (it lives in the
                // heap allocation), so the future may borrow it for `'f`.
                let future_ptr = &raw mut (*fields_ptr).future;
                ptr::write(future_ptr, ManuallyDrop::new(make_fut(&mut *vault_ptr)));

                // Every field is initialised. The cast unsizes the pointer to the
                // `dyn Future` form. `Box::into_raw` never returns null.
                NonNull::new_unchecked(fields_ptr as *mut Progressor<Ctx, Seq, Err>)
            }
        }
    }
429
    impl<Ctx, Seq, Err> Progressor<Ctx, Seq, Err> {
        /// Destroys an allocation created by `make_raw_progressor` and frees it.
        ///
        /// `ptr` must have come from `make_raw_progressor` and must not be used
        /// afterwards. Within this module it is called only from
        /// `StringIter::drop`.
        fn deallocate(ptr: NonNull<Self>) {
            let ptr = ptr.as_ptr();
            unsafe {
                // Drop the future first, while `vault` and `buffer` (which it
                // borrows) are still alive. Because it is wrapped in
                // `ManuallyDrop`, the `Box` drop below does not drop it again.
                {
                    let future_ptr = &raw mut (*ptr).future;
                    let future_ref = (&mut *future_ptr).deref_mut();
                    ptr::drop_in_place(future_ref);
                }
                // Rebuild the `Box` so that dropping it drops the remaining
                // fields and frees the memory. `ptr` is a `dyn Future` fat
                // pointer, so size and alignment come from its vtable and match
                // the concrete type that was allocated.
                let _allocation = Box::from_raw(ptr);

            }
        }
    }
460
    impl<Ctx, Seq, Err> Iterator for StringIter<Ctx, Seq, Err> {
        type Item = Result<String, Err>;

        /// Polls the sequence future once and returns the item it produced.
        ///
        /// Yields `Ok(String)` for each rendered element and `Err` for an
        /// element whose rendering failed. If the sequence itself fails, its
        /// error is yielded last. Returns `None` after the sequence has finished.
        ///
        /// # Panics
        /// Panics if the future is pending without having produced an item,
        /// i.e. if something in the sequence awaits real asynchronous work.
        fn next(&mut self) -> Option<Self::Item> {
            // The future has already completed and must not be polled again.
            if self.finished {
                return None;
            }
            // A sequence error that arrived together with the previous item is
            // delivered now, and it ends the iteration.
            if let Some(error) = mem::take(&mut self.seq_error_in_pipe) {
                self.finished = true;
                return Some(Err(error));
            }
            let fields_ptr = self.inner.as_ptr();
            let (poll_outcome, item) = unsafe {
                // SAFETY: `inner` is live until `Drop`, and the future is never
                // moved out of its heap allocation, so pinning it in place is
                // sound. `&mut self` rules out a concurrent poll.
                let future_ptr = &raw mut (*fields_ptr).future;
                let pinned_future = Pin::new_unchecked((&mut *future_ptr).deref_mut());

                // A no-op waker is enough. The only pending point this code
                // expects is the acceptor waiting for the buffer to be emptied,
                // and that happens right below.
                let poll_outcome =
                    pinned_future.poll(&mut std::task::Context::from_waker(Waker::noop()));

                // `poll` has returned, so the future is no longer touching the
                // buffer. Take whatever item it deposited.
                let buffer_ptr = &raw const (*fields_ptr).buffer;
                let item = (&*buffer_ptr).extract();

                (poll_outcome, item)
            };
            match poll_outcome {
                Poll::Pending => {
                    // Pending is only expected while the future is blocked on a
                    // full buffer, i.e. when it produced an item this round.
                    assert!(
                        item.is_some(),
                        "Extraneous async computations (writable should complete regularly)"
                    );
                }
                Poll::Ready(Err(seq_error)) => {
                    if item.is_some() {
                        // Hand out the item first; the error follows on the next call.
                        self.seq_error_in_pipe = Some(seq_error);
                    } else {
                        self.finished = true;
                        return Some(Err(seq_error));
                    }
                }
                Poll::Ready(Ok(())) => {
                    // `item` may still hold the final element produced by this poll.
                    self.finished = true;
                }
            };
            item
        }
    }
521
522 struct SeqAccept<Ctx, Err> {
523 output: InMemoryOutput<Ctx, Err>,
524 buffer: *const ItemBuffer<Err>,
525 }
526
527 impl<Ctx, Err> SequenceAccept<InMemoryOutput<Ctx, Err>> for SeqAccept<Ctx, Err>
528 where
529 Ctx: Context,
530 Err: From<Infallible>,
531 {
532 async fn accept<W>(&mut self, writable: &W) -> Result<(), Err>
533 where
534 W: Writable<InMemoryOutput<Ctx, Err>>,
535 {
536 poll_fn(|_| {
537 let buffer = unsafe {
538 &*self.buffer
541 };
542 if !buffer.has_space() {
543 return Poll::Pending;
544 }
545 let result = self.output.print_output_impl(writable);
546 let string = mem::take(&mut self.output.buf);
547 buffer.set_new(result.map(|()| string));
548 Poll::Ready(Ok(()))
549 })
550 .await
551 }
552 }
553
554 struct ItemBuffer<Err> {
555 current: UnsafeCell<Option<Result<String, Err>>>,
556 }
557
558 impl<Err> Default for ItemBuffer<Err> {
559 fn default() -> Self {
560 Self {
561 current: UnsafeCell::new(None),
562 }
563 }
564 }
565
566 impl<Err> Debug for ItemBuffer<Err>
567 where
568 Err: Debug,
569 {
570 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
571 unsafe {
572 let current = &*self.current.get();
575 f.debug_struct("ItemBuffer")
576 .field("current", ¤t)
577 .finish()
578 }
579 }
580 }
581
582 impl<Err> ItemBuffer<Err> {
583 fn has_space(&self) -> bool {
584 unsafe {
585 let ptr = self.current.get();
588 (&*ptr).is_none()
589 }
590 }
591
592 fn set_new(&self, value: Result<String, Err>) {
593 unsafe {
594 let ptr = self.current.get();
597 ptr::write(ptr, Some(value));
599 }
600 }
601
602 fn extract(&self) -> Option<Result<String, Err>> {
603 unsafe {
604 let ptr = self.current.get();
607 mem::replace(&mut *ptr, None)
608 }
609 }
610 }
611}
612
#[cfg(test)]
mod tests {
    use crate::common::{CombinedSeq, NoOpSeq, SingularSeq, Str, StrArrSeq};
    use crate::context::EmptyContext;
    use crate::util::IntoStringIter;
    use crate::{Output, SequenceAccept, Writable, WritableSeq};
    use std::convert::Infallible;

    /// Shorthand for the item expected when an element renders to `text`.
    fn ok_item<E>(text: &str) -> Result<String, E> {
        Ok(text.to_owned())
    }

    #[test]
    fn sequence_iterator() {
        let rendered: Vec<Result<String, Infallible>> =
            IntoStringIter::new(EmptyContext, StrArrSeq(&["One", "Two", "Three"]))
                .into_iter()
                .collect();
        let expected: Vec<Result<String, Infallible>> =
            ["One", "Two", "Three"].into_iter().map(ok_item).collect();
        assert_eq!(rendered, expected);
    }

    #[test]
    fn sequence_iterator_empty() {
        // The annotation pins the defaulted `Err = Infallible` parameter.
        let strings: IntoStringIter<_, _> = IntoStringIter::new(EmptyContext, NoOpSeq);
        assert_eq!(strings.into_iter().count(), 0);
    }

    /// Test sequence that fails with `SampleError`, either immediately
    /// (`emit_before`) or after forwarding every element of `seq`.
    #[derive(Clone)]
    struct SequenceWithError<Seq> {
        emit_before: bool,
        seq: Seq,
    }

    #[derive(Clone, Debug, PartialEq, Eq)]
    struct SampleError;

    // Required because `InMemoryOutput` demands `Err: From<Infallible>`.
    impl From<Infallible> for SampleError {
        fn from(never: Infallible) -> Self {
            match never {}
        }
    }

    impl<O, Seq> WritableSeq<O> for SequenceWithError<Seq>
    where
        O: Output<Error = SampleError>,
        Seq: WritableSeq<O>,
    {
        async fn for_each<S>(&self, sink: &mut S) -> Result<(), O::Error>
        where
            S: SequenceAccept<O>,
        {
            if self.emit_before {
                return Err(SampleError);
            }
            self.seq.for_each(sink).await?;
            Err(SampleError)
        }
    }

    #[test]
    fn sequence_iterator_seq_error() {
        let strings = IntoStringIter::<_, _, SampleError>::new(
            EmptyContext,
            SequenceWithError {
                emit_before: true,
                seq: StrArrSeq(&["Will", "Never", "Be", "Seen"]),
            },
        );
        assert_eq!(strings.clone().into_iter().next(), Some(Err(SampleError)));
        assert!(!strings.into_iter().any(|item| item.is_ok()));
    }

    #[test]
    fn sequence_iterator_seq_error_afterward() {
        let strings = IntoStringIter::<_, _, SampleError>::new(
            EmptyContext,
            SequenceWithError {
                emit_before: false,
                seq: StrArrSeq(&["Data", "More"]),
            },
        );
        let rendered: Vec<_> = strings.into_iter().collect();
        assert_eq!(
            rendered,
            vec![ok_item("Data"), ok_item("More"), Err(SampleError)]
        );
    }

    #[test]
    fn sequence_iterator_seq_error_in_between() {
        let sequence = CombinedSeq(
            StrArrSeq(&["One", "Two"]),
            SequenceWithError {
                emit_before: true,
                seq: SingularSeq(Str("Final")),
            },
        );
        let rendered: Vec<_> = IntoStringIter::<_, _, SampleError>::new(EmptyContext, sequence)
            .into_iter()
            .collect();
        assert_eq!(
            rendered,
            vec![ok_item("One"), ok_item("Two"), Err(SampleError)]
        );
    }

    #[test]
    fn sequence_iterator_seq_error_empty() {
        let strings = IntoStringIter::<_, _, SampleError>::new(
            EmptyContext,
            SequenceWithError {
                emit_before: true,
                seq: NoOpSeq,
            },
        );
        let rendered: Vec<_> = strings.into_iter().collect();
        assert_eq!(rendered, vec![Err(SampleError)]);
    }

    /// Writable that always fails with `SampleError`.
    #[derive(Clone, Debug)]
    struct ProduceError;

    impl<O> Writable<O> for ProduceError
    where
        O: Output<Error = SampleError>,
    {
        async fn write_to(&self, _output: &mut O) -> Result<(), O::Error> {
            Err(SampleError)
        }
    }

    #[test]
    fn sequence_iterator_write_error() {
        let strings = IntoStringIter::<_, _, SampleError>::new(
            EmptyContext,
            CombinedSeq(SingularSeq(ProduceError), StrArrSeq(&["Is", "Seen"])),
        );
        assert_eq!(strings.clone().into_iter().next(), Some(Err(SampleError)));
        let rendered: Vec<_> = strings.into_iter().collect();
        assert_eq!(
            rendered,
            vec![Err(SampleError), ok_item("Is"), ok_item("Seen")]
        );
    }

    #[test]
    fn sequence_iterator_write_error_afterward() {
        let sequence = CombinedSeq(StrArrSeq(&["Data", "MoreData"]), SingularSeq(ProduceError));
        let rendered: Vec<_> = IntoStringIter::<_, _, SampleError>::new(EmptyContext, sequence)
            .into_iter()
            .collect();
        assert_eq!(
            rendered,
            vec![ok_item("Data"), ok_item("MoreData"), Err(SampleError)]
        );
    }

    #[test]
    fn sequence_iterator_write_error_in_between() {
        let sequence = CombinedSeq(
            StrArrSeq(&["Data", "Adjacent"]),
            CombinedSeq(SingularSeq(ProduceError), SingularSeq(Str("Final"))),
        );
        let rendered: Vec<_> = IntoStringIter::<_, _, SampleError>::new(EmptyContext, sequence)
            .into_iter()
            .collect();
        assert_eq!(
            rendered,
            vec![
                ok_item("Data"),
                ok_item("Adjacent"),
                Err(SampleError),
                ok_item("Final"),
            ]
        );
    }
}