1use crate::common::SequenceConfig;
18use crate::context::Context;
19use crate::{IoOutput, Output, Writable, WritableSeq};
20use std::convert::Infallible;
21use std::fmt::Debug;
22use std::marker::PhantomData;
23use std::pin::pin;
24use std::task::{Poll, Waker};
25
26pub struct WritableFromFunction<F>(pub F);
29
30impl<F, T, W, O> SequenceConfig<T, O> for WritableFromFunction<F>
31where
32 O: Output,
33 F: Fn(&T) -> W,
34 W: Writable<O>,
35{
36 async fn write_datum(&self, datum: &T, output: &mut O) -> Result<(), O::Error> {
37 let writable = (&self.0)(datum);
38 writable.write_to(output).await
39 }
40}
41
42pub struct InMemoryIo<'s>(pub &'s mut String);
45
46impl IoOutput for InMemoryIo<'_> {
47 type Error = Infallible;
48
49 async fn write(&mut self, value: &str) -> Result<(), Self::Error> {
50 self.0.push_str(value);
51 Ok(())
52 }
53}
54
/// An output that accumulates everything written to it in an in-memory `String`.
///
/// `Err` is only a type-level parameter: no `Err` value is ever stored. The
/// `Output` implementation requires `Err: From<Infallible>`.
pub struct InMemoryOutput<Ctx, Err = Infallible> {
    /// Text written so far.
    buf: String,
    /// Context handed out by `Output::context` and `Output::split`.
    context: Ctx,
    /// Ties `Err` to the type without storing one. The function-pointer form keeps
    /// the struct's auto traits (`Send`/`Sync`) independent of `Err`.
    error_type: PhantomData<fn(Infallible) -> Err>,
}

impl<Ctx, Err> InMemoryOutput<Ctx, Err> {
    /// Creates an empty output that carries `context`.
    pub fn new(context: Ctx) -> Self {
        let buf = String::new();
        Self {
            buf,
            context,
            error_type: PhantomData,
        }
    }
}
79
80impl<Ctx, Err> Output for InMemoryOutput<Ctx, Err>
81where
82 Ctx: Context,
83 Err: From<Infallible>,
84{
85 type Io<'b>
86 = InMemoryIo<'b>
87 where
88 Self: 'b;
89 type Ctx = Ctx;
90 type Error = Err;
91
92 async fn write(&mut self, value: &str) -> Result<(), Self::Error> {
93 self.buf.push_str(value);
94 Ok(())
95 }
96
97 fn split(&mut self) -> (Self::Io<'_>, &Self::Ctx) {
98 (InMemoryIo(&mut self.buf), &self.context)
99 }
100
101 fn context(&self) -> &Self::Ctx {
102 &self.context
103 }
104}
105
106impl<Ctx, Err> InMemoryOutput<Ctx, Err>
107where
108 Ctx: Context,
109 Err: From<Infallible>,
110{
111 pub fn try_print_output<W>(context: Ctx, writable: &W) -> Result<String, Err>
120 where
121 W: Writable<Self>,
122 {
123 let mut output = Self::new(context);
124 let result = output.print_output_impl(writable);
125 result.map(|()| output.buf)
126 }
127
128 fn print_output_impl<W>(&mut self, writable: &W) -> Result<(), Err>
129 where
130 W: Writable<Self>,
131 {
132 let future = pin!(writable.write_to(self));
133 match future.poll(&mut std::task::Context::from_waker(Waker::noop())) {
134 Poll::Pending => panic!("Expected a complete future"),
135 Poll::Ready(result) => result,
136 }
137 }
138}
139
140impl<Ctx> InMemoryOutput<Ctx>
141where
142 Ctx: Context,
143{
144 pub fn print_output<W>(context: Ctx, writable: &W) -> String
152 where
153 W: Writable<Self>,
154 {
155 Self::try_print_output(context, writable).unwrap_or_else(|e| match e {})
156 }
157}
158
/// Pairs a rendering context with a sequence whose elements are rendered one
/// `String` at a time.
///
/// Calling `into_iter` yields one `Result<String, Err>` per element the sequence
/// emits, followed by an `Err` if the sequence itself fails. The value is
/// cloneable whenever `Ctx` and `Seq` are.
pub struct IntoStringIter<Ctx, Seq, Err = Infallible> {
    context: Ctx,
    sequence: Seq,
    error_type: PhantomData<fn(Infallible) -> Err>,
}

impl<Ctx, Seq, Err> IntoStringIter<Ctx, Seq, Err> {
    /// Bundles the `context` used for rendering with the `sequence` to render.
    pub fn new(context: Ctx, sequence: Seq) -> Self {
        IntoStringIter {
            context,
            sequence,
            error_type: PhantomData,
        }
    }
}
201
202impl<Ctx, Seq, Err> Clone for IntoStringIter<Ctx, Seq, Err>
203where
204 Ctx: Clone,
205 Seq: Clone,
206{
207 fn clone(&self) -> Self {
208 Self {
209 context: self.context.clone(),
210 sequence: self.sequence.clone(),
211 error_type: PhantomData,
212 }
213 }
214}
215
216impl<Ctx, Seq, Err> Debug for IntoStringIter<Ctx, Seq, Err> where Ctx: Debug, Seq: Debug {
217 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
218 f.debug_struct("IntoStringIter")
219 .field("context", &self.context)
220 .field("sequence", &self.sequence)
221 .field("error_type", &std::any::type_name::<Err>())
222 .finish()
223 }
224}
225
226impl<Ctx, Seq, Err> IntoIterator for IntoStringIter<Ctx, Seq, Err>
227where
228 Ctx: Context,
229 Seq: WritableSeq<InMemoryOutput<Ctx, Err>>,
230 Err: From<Infallible>,
231{
232 type Item = Result<String, Err>;
233 type IntoIter = ToStringIter<Ctx, Seq, Err>;
234 fn into_iter(self) -> Self::IntoIter {
235 ToStringIter(string_iter::StringIter::new(self.context, self.sequence))
236 }
237}
238
239pub struct ToStringIter<Ctx, Seq, Err = Infallible>(string_iter::StringIter<Ctx, Seq, Err>);
241
242impl<Ctx, Seq, Err> Debug for ToStringIter<Ctx, Seq, Err> where Err: Debug {
243 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
244 f.debug_struct("ToStringIter")
245 .field("inner", &self.0)
246 .finish()
247 }
248}
249
250impl<Ctx, Seq, Err> Iterator for ToStringIter<Ctx, Seq, Err>
251where
252 Ctx: Context,
253 Seq: WritableSeq<InMemoryOutput<Ctx, Err>>,
254 Err: From<Infallible>,
255{
256 type Item = Result<String, Err>;
257 fn next(&mut self) -> Option<Self::Item> {
258 self.0.next()
259 }
260}
261
262mod string_iter {
263 use crate::context::Context;
264 use crate::util::InMemoryOutput;
265 use crate::{SequenceAccept, Writable, WritableSeq};
266 use std::cell::UnsafeCell;
267 use std::convert::Infallible;
268 use std::future::poll_fn;
269 use std::marker::PhantomData;
270 use std::mem::{ManuallyDrop, MaybeUninit};
271 use std::ops::DerefMut;
272 use std::pin::Pin;
273 use std::ptr::NonNull;
274 use std::task::{Poll, Waker};
275 use std::{mem, ptr};
276 use std::fmt::Debug;
277
278 pub struct StringIter<Ctx, Seq, Err> {
279 marker: PhantomData<(Ctx, Seq)>,
281 inner: NonNull<Progressor<Ctx, Seq, Err>>,
282 seq_error_in_pipe: Option<Err>,
285 finished: bool,
286 }
287
288 impl<Ctx, Seq, Err> Debug for StringIter<Ctx, Seq, Err> where Err: Debug {
289 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
290 let inner = self.inner.as_ptr();
291 let buffer = unsafe {
292 &*(&raw const (*inner).buffer)
295 };
296 f.debug_struct("StringIter")
297 .field("marker", &(&std::any::type_name::<Ctx>(), &std::any::type_name::<Seq>()))
298 .field("inner.buffer", buffer)
299 .field("seq_error_in_pipe", &self.seq_error_in_pipe)
300 .field("finished", &self.finished)
301 .finish()
302 }
303 }
304
305 impl<Ctx, Seq, Err> Drop for StringIter<Ctx, Seq, Err> {
306 fn drop(&mut self) {
307 Progressor::deallocate(self.inner)
308 }
309 }
310
311 type Progressor<Ctx, Seq, Err> =
312 RawProgressor<Ctx, Seq, Err, dyn Future<Output = Result<(), Err>>>;
313
314 struct RawProgressor<Ctx, Seq, Err, Fut: ?Sized> {
318 buffer: ItemBuffer<Err>,
320 vault: RawProgressorVault<Ctx, Seq, Err>,
323 future: ManuallyDrop<Fut>,
331 }
332
333 struct RawProgressorVault<Ctx, Seq, Err> {
334 acceptor: SeqAccept<Ctx, Err>,
336 sequence: Seq,
338 }
339
340 impl<Ctx, Seq, Err> StringIter<Ctx, Seq, Err>
341 where
342 Ctx: Context,
343 Seq: WritableSeq<InMemoryOutput<Ctx, Err>>,
344 Err: From<Infallible>,
345 {
346 pub fn new(context: Ctx, sequence: Seq) -> Self {
347 let ptr = Self::make_raw_progressor(context, sequence, |vault| {
348 WritableSeq::for_each(&vault.sequence, &mut vault.acceptor)
349 });
350 Self {
351 marker: PhantomData,
352 inner: ptr,
353 seq_error_in_pipe: None,
354 finished: false,
355 }
356 }
357
358 fn make_raw_progressor<'f, MakeFut, Fut>(
368 context: Ctx,
369 sequence: Seq,
370 make_fut: MakeFut,
371 ) -> NonNull<Progressor<Ctx, Seq, Err>>
372 where
373 Fut: Future<Output = Result<(), Err>> + 'f,
374 MakeFut: FnOnce(&'f mut RawProgressorVault<Ctx, Seq, Err>) -> Fut,
375 Ctx: 'f,
376 Seq: 'f,
377 Err: 'f,
378 {
379 let allocated = Box::new(MaybeUninit::<RawProgressor<Ctx, Seq, Err, Fut>>::uninit());
382 unsafe {
383 let fields_ptr = Box::into_raw(allocated);
387 let fields_ptr = (&mut *fields_ptr).as_mut_ptr();
388
389 let buffer_ptr = &raw mut (*fields_ptr).buffer;
391 ptr::write(buffer_ptr, ItemBuffer::default());
392
393 let vault_ptr = &raw mut (*fields_ptr).vault;
396 ptr::write(
397 vault_ptr,
398 RawProgressorVault {
399 acceptor: SeqAccept {
400 output: InMemoryOutput::new(context),
401 buffer: buffer_ptr,
402 },
403 sequence,
404 },
405 );
406
407 let future_ptr = &raw mut (*fields_ptr).future;
409 ptr::write(future_ptr, ManuallyDrop::new(make_fut(&mut *vault_ptr)));
410
411 NonNull::new_unchecked(fields_ptr as *mut Progressor<Ctx, Seq, Err>)
413 }
414 }
415 }
416
417 impl<Ctx, Seq, Err> Progressor<Ctx, Seq, Err> {
418 fn deallocate(ptr: NonNull<Self>) {
419 let ptr = ptr.as_ptr();
420 unsafe {
421 {
424 let future_ptr = &raw mut (*ptr).future;
425 let future_ref = (&mut *future_ptr).deref_mut();
426 ptr::drop_in_place(future_ref);
427 }
428 let _allocation = Box::from_raw(ptr);
429
430 }
445 }
446 }
447
448 impl<Ctx, Seq, Err> Iterator for StringIter<Ctx, Seq, Err> {
449 type Item = Result<String, Err>;
450 fn next(&mut self) -> Option<Self::Item> {
451 if self.finished {
452 return None;
453 }
454 if let Some(error) = mem::take(&mut self.seq_error_in_pipe) {
456 self.finished = true;
457 return Some(Err(error));
458 }
459 let fields_ptr = self.inner.as_ptr();
460 let (poll_outcome, item) = unsafe {
461 let future_ptr = &raw mut (*fields_ptr).future;
465 let pinned_future = Pin::new_unchecked((&mut *future_ptr).deref_mut());
466
467 let poll_outcome =
469 pinned_future.poll(&mut std::task::Context::from_waker(Waker::noop()));
470
471 let buffer_ptr = &raw const (*fields_ptr).buffer;
473 let item = (&*buffer_ptr).extract();
474
475 (poll_outcome, item)
476 };
477 match poll_outcome {
478 Poll::Pending => {
479 assert!(
482 item.is_some(),
483 "Extraneous async computations (writable should complete regularly)"
484 );
485 }
486 Poll::Ready(Err(seq_error)) => {
487 if item.is_some() {
490 self.seq_error_in_pipe = Some(seq_error);
493 } else {
494 self.finished = true;
496 return Some(Err(seq_error));
497 }
498 }
499 Poll::Ready(Ok(())) => {
500 self.finished = true;
502 }
504 };
505 item
506 }
507 }
508
509 struct SeqAccept<Ctx, Err> {
510 output: InMemoryOutput<Ctx, Err>,
511 buffer: *const ItemBuffer<Err>,
512 }
513
514 impl<Ctx, Err> SequenceAccept<InMemoryOutput<Ctx, Err>> for SeqAccept<Ctx, Err>
515 where
516 Ctx: Context,
517 Err: From<Infallible>,
518 {
519 async fn accept<W>(&mut self, writable: &W) -> Result<(), Err>
520 where
521 W: Writable<InMemoryOutput<Ctx, Err>>,
522 {
523 poll_fn(|_| {
524 let buffer = unsafe {
525 &*self.buffer
528 };
529 if !buffer.has_space() {
530 return Poll::Pending;
531 }
532 let result = self.output.print_output_impl(writable);
533 let string = mem::take(&mut self.output.buf);
534 buffer.set_new(result.map(|()| string));
535 Poll::Ready(Ok(()))
536 })
537 .await
538 }
539 }
540
541 struct ItemBuffer<Err> {
542 current: UnsafeCell<Option<Result<String, Err>>>,
543 }
544
545 impl<Err> Default for ItemBuffer<Err> {
546 fn default() -> Self {
547 Self {
548 current: UnsafeCell::new(None),
549 }
550 }
551 }
552
553 impl<Err> Debug for ItemBuffer<Err> where Err: Debug {
554 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
555 unsafe {
556 let current = &*self.current.get();
559 f.debug_struct("ItemBuffer")
560 .field("current", ¤t)
561 .finish()
562 }
563 }
564 }
565
566 impl<Err> ItemBuffer<Err> {
567 fn has_space(&self) -> bool {
568 unsafe {
569 let ptr = self.current.get();
572 (&*ptr).is_none()
573 }
574 }
575
576 fn set_new(&self, value: Result<String, Err>) {
577 unsafe {
578 let ptr = self.current.get();
581 ptr::write(ptr, Some(value));
583 }
584 }
585
586 fn extract(&self) -> Option<Result<String, Err>> {
587 unsafe {
588 let ptr = self.current.get();
591 mem::replace(&mut *ptr, None)
592 }
593 }
594 }
595}
596
#[cfg(test)]
mod tests {
    use crate::common::{CombinedSeq, NoOpSeq, SingularSeq, Str, StrArrSeq};
    use crate::context::EmptyContext;
    use crate::util::IntoStringIter;
    use crate::{Output, SequenceAccept, Writable, WritableSeq};
    use std::convert::Infallible;

    /// Error type used by the failing fixtures below.
    #[derive(Clone, Debug, PartialEq, Eq)]
    struct SampleError;

    impl From<Infallible> for SampleError {
        fn from(never: Infallible) -> Self {
            match never {}
        }
    }

    /// Shorthand for a successfully rendered item.
    fn ok_str(text: &str) -> Result<String, SampleError> {
        Ok(text.to_owned())
    }

    /// A sequence that always ends in `SampleError`.
    ///
    /// With `emit_before` set, it fails without visiting `seq`. Otherwise it
    /// first forwards every element of `seq`.
    #[derive(Clone)]
    struct SequenceWithError<Seq> {
        emit_before: bool,
        seq: Seq,
    }

    impl<O, Seq> WritableSeq<O> for SequenceWithError<Seq>
    where
        O: Output<Error = SampleError>,
        Seq: WritableSeq<O>,
    {
        async fn for_each<S>(&self, sink: &mut S) -> Result<(), O::Error>
        where
            S: SequenceAccept<O>,
        {
            if self.emit_before {
                return Err(SampleError);
            }
            self.seq.for_each(sink).await?;
            Err(SampleError)
        }
    }

    /// A writable that always fails with `SampleError` without writing anything.
    #[derive(Clone, Debug)]
    struct ProduceError;

    impl<O> Writable<O> for ProduceError
    where
        O: Output<Error = SampleError>,
    {
        async fn write_to(&self, _output: &mut O) -> Result<(), O::Error> {
            Err(SampleError)
        }
    }

    #[test]
    fn sequence_iterator() {
        let rendered = IntoStringIter::new(EmptyContext, StrArrSeq(&["One", "Two", "Three"]))
            .into_iter()
            .collect::<Vec<_>>();
        let expected = Vec::from(
            ["One", "Two", "Three"].map(|word| Ok::<_, Infallible>(word.to_string())),
        );
        assert_eq!(rendered, expected);
    }

    #[test]
    fn sequence_iterator_empty() {
        let strings: IntoStringIter<_, _> = IntoStringIter::new(EmptyContext, NoOpSeq);
        assert_eq!(strings.into_iter().count(), 0);
    }

    #[test]
    fn sequence_iterator_seq_error() {
        let strings = IntoStringIter::<_, _, SampleError>::new(
            EmptyContext,
            SequenceWithError {
                emit_before: true,
                seq: StrArrSeq(&["Will", "Never", "Be", "Seen"]),
            },
        );
        let first = strings.clone().into_iter().next();
        assert_eq!(first, Some(Err(SampleError)));
        assert!(!strings.into_iter().any(|item| item.is_ok()));
    }

    #[test]
    fn sequence_iterator_seq_error_afterward() {
        let sequence = SequenceWithError {
            emit_before: false,
            seq: StrArrSeq(&["Data", "More"]),
        };
        let rendered: Vec<_> = IntoStringIter::<_, _, SampleError>::new(EmptyContext, sequence)
            .into_iter()
            .collect();
        assert_eq!(rendered, vec![ok_str("Data"), ok_str("More"), Err(SampleError)]);
    }

    #[test]
    fn sequence_iterator_seq_error_in_between() {
        let failing_tail = SequenceWithError {
            emit_before: true,
            seq: SingularSeq(Str("Final")),
        };
        let sequence = CombinedSeq(StrArrSeq(&["One", "Two"]), failing_tail);
        let rendered: Vec<_> = IntoStringIter::<_, _, SampleError>::new(EmptyContext, sequence)
            .into_iter()
            .collect();
        assert_eq!(rendered, vec![ok_str("One"), ok_str("Two"), Err(SampleError)]);
    }

    #[test]
    fn sequence_iterator_seq_error_empty() {
        let sequence = SequenceWithError {
            emit_before: true,
            seq: NoOpSeq,
        };
        let rendered: Vec<_> = IntoStringIter::<_, _, SampleError>::new(EmptyContext, sequence)
            .into_iter()
            .collect();
        assert_eq!(rendered, vec![Err(SampleError)]);
    }

    #[test]
    fn sequence_iterator_write_error() {
        let sequence = CombinedSeq(SingularSeq(ProduceError), StrArrSeq(&["Is", "Seen"]));
        let strings = IntoStringIter::<_, _, SampleError>::new(EmptyContext, sequence);
        assert_eq!(strings.clone().into_iter().next(), Some(Err(SampleError)));
        let rendered: Vec<_> = strings.into_iter().collect();
        assert_eq!(rendered, vec![Err(SampleError), ok_str("Is"), ok_str("Seen")]);
    }

    #[test]
    fn sequence_iterator_write_error_afterward() {
        let sequence = CombinedSeq(StrArrSeq(&["Data", "MoreData"]), SingularSeq(ProduceError));
        let rendered: Vec<_> = IntoStringIter::<_, _, SampleError>::new(EmptyContext, sequence)
            .into_iter()
            .collect();
        assert_eq!(
            rendered,
            vec![ok_str("Data"), ok_str("MoreData"), Err(SampleError)]
        );
    }

    #[test]
    fn sequence_iterator_write_error_in_between() {
        let tail = CombinedSeq(SingularSeq(ProduceError), SingularSeq(Str("Final")));
        let sequence = CombinedSeq(StrArrSeq(&["Data", "Adjacent"]), tail);
        let rendered: Vec<_> = IntoStringIter::<_, _, SampleError>::new(EmptyContext, sequence)
            .into_iter()
            .collect();
        assert_eq!(
            rendered,
            vec![
                ok_str("Data"),
                ok_str("Adjacent"),
                Err(SampleError),
                ok_str("Final"),
            ]
        );
    }
}