1use crate::common::SequenceConfig;
18use crate::context::Context;
19use crate::{IoOutput, Output, Writable, WritableSeq};
20use std::convert::Infallible;
21use std::fmt::Debug;
22use std::marker::PhantomData;
23use std::pin::pin;
24use std::task::{Poll, Waker};
25
/// Adapts a function into a `SequenceConfig`: each datum is mapped by the wrapped function to a
/// value implementing `Writable`, which is then written to the output.
///
/// `Debug`, `Clone` and `Copy` are derived conditionally, i.e. only when `F` itself implements
/// them (plain `fn` items and non-capturing closures are `Clone + Copy`).
#[derive(Clone, Copy, Debug)]
pub struct WritableFromFunction<F>(pub F);
29
30impl<F, T, W, O> SequenceConfig<T, O> for WritableFromFunction<F>
31where
32 O: Output,
33 F: Fn(&T) -> W,
34 W: Writable<O>,
35{
36 async fn write_datum(&self, datum: &T, output: &mut O) -> Result<(), O::Error> {
37 let writable = (&self.0)(datum);
38 writable.write_to(output).await
39 }
40}
41
/// Writer half of `InMemoryOutput::split`: appends everything written to it to the borrowed
/// `String`.
#[derive(Debug)]
pub struct InMemoryIo<'s>(pub &'s mut String);
45
46impl IoOutput for InMemoryIo<'_> {
47 type Error = Infallible;
48
49 async fn write(&mut self, value: &str) -> Result<(), Self::Error> {
50 self.0.push_str(value);
51 Ok(())
52 }
53}
54
/// `Output` that accumulates everything written to it in an in-memory `String`.
///
/// `Err` is the error type reported through the `Output` trait. It defaults to
/// `Infallible` because appending to memory never fails by itself.
pub struct InMemoryOutput<Ctx, Err = Infallible> {
    /// Text written so far.
    buf: String,
    /// Context handed out through `Output::context` / `Output::split`.
    context: Ctx,
    /// `fn(..) -> Err` keeps the marker covariant in `Err` and `Send`/`Sync` regardless of
    /// `Err`, without claiming to own an `Err` value.
    error_type: PhantomData<fn(Infallible) -> Err>,
}

/// Mirrors `IntoStringIter`'s `Debug`: `Err` is shown by type name only, so it need not be
/// `Debug` itself.
impl<Ctx, Err> Debug for InMemoryOutput<Ctx, Err>
where
    Ctx: Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("InMemoryOutput")
            .field("buf", &self.buf)
            .field("context", &self.context)
            .field("error_type", &std::any::type_name::<Err>())
            .finish()
    }
}
69
70impl<Ctx, Err> InMemoryOutput<Ctx, Err> {
71 pub fn new(context: Ctx) -> Self {
72 Self {
73 buf: String::new(),
74 context,
75 error_type: PhantomData,
76 }
77 }
78}
79
80impl<Ctx, Err> Output for InMemoryOutput<Ctx, Err>
81where
82 Ctx: Context,
83 Err: From<Infallible>,
84{
85 type Io<'b>
86 = InMemoryIo<'b>
87 where
88 Self: 'b;
89 type Ctx = Ctx;
90 type Error = Err;
91
92 async fn write(&mut self, value: &str) -> Result<(), Self::Error> {
93 self.buf.push_str(value);
94 Ok(())
95 }
96
97 fn split(&mut self) -> (Self::Io<'_>, &Self::Ctx) {
98 (InMemoryIo(&mut self.buf), &self.context)
99 }
100
101 fn context(&self) -> &Self::Ctx {
102 &self.context
103 }
104}
105
106impl<Ctx, Err> InMemoryOutput<Ctx, Err>
107where
108 Ctx: Context,
109 Err: From<Infallible>,
110{
111 pub fn try_print_output<W>(context: Ctx, writable: &W) -> Result<String, Err>
120 where
121 W: Writable<Self>,
122 {
123 let mut output = Self::new(context);
124 let result = output.print_output_impl(writable);
125 result.map(|()| output.buf)
126 }
127
128 fn print_output_impl<W>(&mut self, writable: &W) -> Result<(), Err>
129 where
130 W: Writable<Self>,
131 {
132 let future = pin!(writable.write_to(self));
133 match future.poll(&mut std::task::Context::from_waker(Waker::noop())) {
134 Poll::Pending => panic!("Expected a complete future"),
135 Poll::Ready(result) => result,
136 }
137 }
138}
139
140impl<Ctx> InMemoryOutput<Ctx>
141where
142 Ctx: Context,
143{
144 pub fn print_output<W>(context: Ctx, writable: &W) -> String
152 where
153 W: Writable<Self>,
154 {
155 Self::try_print_output(context, writable).unwrap_or_else(|e| match e {})
156 }
157}
158
/// Adapter that turns a `WritableSeq` into an iterator of rendered strings.
///
/// Every element of the sequence is rendered into its own `String`. A failure while writing
/// one element is yielded in place of that element and iteration continues; an error returned
/// by the sequence itself is yielded last. `Err` defaults to `Infallible`.
pub struct IntoStringIter<Ctx, Seq, Err = Infallible> {
    /// Context used to render each element.
    context: Ctx,
    /// The sequence whose elements are rendered.
    sequence: Seq,
    /// Type-level marker for the error type; owns no `Err` value.
    error_type: PhantomData<fn(Infallible) -> Err>,
}
191
192impl<Ctx, Seq, Err> IntoStringIter<Ctx, Seq, Err> {
193 pub fn new(context: Ctx, sequence: Seq) -> Self {
194 Self {
195 context,
196 sequence,
197 error_type: PhantomData,
198 }
199 }
200}
201
202impl<Ctx, Seq, Err> Clone for IntoStringIter<Ctx, Seq, Err>
203where
204 Ctx: Clone,
205 Seq: Clone,
206{
207 fn clone(&self) -> Self {
208 Self {
209 context: self.context.clone(),
210 sequence: self.sequence.clone(),
211 error_type: PhantomData,
212 }
213 }
214}
215
216impl<Ctx, Seq, Err> Debug for IntoStringIter<Ctx, Seq, Err>
217where
218 Ctx: Debug,
219 Seq: Debug,
220{
221 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
222 f.debug_struct("IntoStringIter")
223 .field("context", &self.context)
224 .field("sequence", &self.sequence)
225 .field("error_type", &std::any::type_name::<Err>())
226 .finish()
227 }
228}
229
230impl<Ctx, Seq, Err> IntoIterator for IntoStringIter<Ctx, Seq, Err>
231where
232 Ctx: Context,
233 Seq: WritableSeq<InMemoryOutput<Ctx, Err>>,
234 Err: From<Infallible>,
235{
236 type Item = Result<String, Err>;
237 type IntoIter = ToStringIter<Ctx, Seq, Err>;
238 fn into_iter(self) -> Self::IntoIter {
239 ToStringIter(string_iter::StringIter::new(self.context, self.sequence))
240 }
241}
242
243pub struct ToStringIter<Ctx, Seq, Err = Infallible>(string_iter::StringIter<Ctx, Seq, Err>);
245
246impl<Ctx, Seq, Err> Debug for ToStringIter<Ctx, Seq, Err>
247where
248 Err: Debug,
249{
250 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
251 f.debug_struct("ToStringIter")
252 .field("inner", &self.0)
253 .finish()
254 }
255}
256
257impl<Ctx, Seq, Err> Iterator for ToStringIter<Ctx, Seq, Err>
258where
259 Ctx: Context,
260 Seq: WritableSeq<InMemoryOutput<Ctx, Err>>,
261 Err: From<Infallible>,
262{
263 type Item = Result<String, Err>;
264 fn next(&mut self) -> Option<Self::Item> {
265 self.0.next()
266 }
267}
268
269mod string_iter {
270 use crate::context::Context;
271 use crate::util::InMemoryOutput;
272 use crate::{SequenceAccept, Writable, WritableSeq};
273 use std::cell::Cell;
274 use std::convert::Infallible;
275 use std::fmt::Debug;
276 use std::future::poll_fn;
277 use std::mem::{ManuallyDrop, MaybeUninit};
278 use std::ops::DerefMut;
279 use std::pin::Pin;
280 use std::ptr::NonNull;
281 use std::task::{Poll, Waker};
282 use std::{mem, ptr};
283
284 pub struct StringIter<Ctx, Seq, Err> {
285 progressor: NonNull<Progressor<Ctx, Seq, Err>>,
286 seq_error_in_pipe: Option<Err>,
289 finished: bool,
290 }
291
292 impl<Ctx, Seq, Err> Debug for StringIter<Ctx, Seq, Err>
293 where
294 Err: Debug,
295 {
296 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
297 let progressor = self.progressor.as_ptr();
298 let buffer = unsafe {
299 &*(&raw const (*progressor).buffer)
302 };
303 f.debug_struct("StringIter")
304 .field(
305 "marker",
306 &(&std::any::type_name::<Ctx>(), &std::any::type_name::<Seq>()),
307 )
308 .field("progressor.buffer", buffer)
309 .field("seq_error_in_pipe", &self.seq_error_in_pipe)
310 .field("finished", &self.finished)
311 .finish()
312 }
313 }
314
315 impl<Ctx, Seq, Err> Drop for StringIter<Ctx, Seq, Err> {
316 fn drop(&mut self) {
317 Progressor::deallocate(self.progressor)
318 }
319 }
320
321 type Progressor<Ctx, Seq, Err> =
322 RawProgressor<Ctx, Seq, Err, dyn Future<Output = Result<(), Err>>>;
323
324 struct RawProgressor<Ctx, Seq, Err, Fut: ?Sized> {
328 buffer: ItemBuffer<Err>,
330 vault: RawProgressorVault<Ctx, Seq, Err>,
333 future: ManuallyDrop<Fut>,
341 }
342
343 struct RawProgressorVault<Ctx, Seq, Err> {
344 acceptor: SeqAccept<Ctx, Err>,
346 sequence: Seq,
348 }
349
350 impl<Ctx, Seq, Err> StringIter<Ctx, Seq, Err>
351 where
352 Ctx: Context,
353 Seq: WritableSeq<InMemoryOutput<Ctx, Err>>,
354 Err: From<Infallible>,
355 {
356 pub fn new(context: Ctx, sequence: Seq) -> Self {
357 let ptr = Self::make_raw_progressor(context, sequence, |vault| {
358 WritableSeq::for_each(&vault.sequence, &mut vault.acceptor)
359 });
360 Self {
361 progressor: ptr,
362 seq_error_in_pipe: None,
363 finished: false,
364 }
365 }
366
367 fn make_raw_progressor<'f, MakeFut, Fut>(
377 context: Ctx,
378 sequence: Seq,
379 make_fut: MakeFut,
380 ) -> NonNull<Progressor<Ctx, Seq, Err>>
381 where
382 Fut: Future<Output = Result<(), Err>> + 'f,
383 MakeFut: FnOnce(&'f mut RawProgressorVault<Ctx, Seq, Err>) -> Fut,
384 Ctx: 'f,
385 Seq: 'f,
386 Err: 'f,
387 {
388 let allocated = Box::new(MaybeUninit::<RawProgressor<Ctx, Seq, Err, Fut>>::uninit());
391 unsafe {
392 let fields_ptr = Box::into_raw(allocated);
396 let fields_ptr = (&mut *fields_ptr).as_mut_ptr();
397
398 let buffer_ptr = &raw mut (*fields_ptr).buffer;
400 ptr::write(buffer_ptr, ItemBuffer::default());
401
402 let vault_ptr = &raw mut (*fields_ptr).vault;
405 ptr::write(
406 vault_ptr,
407 RawProgressorVault {
408 acceptor: SeqAccept {
409 output: InMemoryOutput::new(context),
410 buffer: buffer_ptr,
411 },
412 sequence,
413 },
414 );
415
416 let future_ptr = &raw mut (*fields_ptr).future;
418 ptr::write(future_ptr, ManuallyDrop::new(make_fut(&mut *vault_ptr)));
419
420 let fields_ptr = mem::transmute::<
423 *mut RawProgressor<_, _, _, dyn Future<Output = Result<(), Err>> + 'f>,
424 *mut RawProgressor<_, _, _, dyn Future<Output = Result<(), Err>> + 'static>,
425 >(fields_ptr);
426 NonNull::<Progressor<Ctx, Seq, Err>>::new_unchecked(fields_ptr)
427 }
428 }
429 }
430
431 impl<Ctx, Seq, Err> Progressor<Ctx, Seq, Err> {
432 fn deallocate(ptr: NonNull<Self>) {
433 let ptr = ptr.as_ptr();
434 unsafe {
435 {
438 let future_ptr = &raw mut (*ptr).future;
439 let future_to_drop = &mut *future_ptr;
440 ManuallyDrop::drop(future_to_drop);
441 }
442 let _allocation = Box::from_raw(ptr);
443 }
444 }
445 }
446
447 impl<Ctx, Seq, Err> Iterator for StringIter<Ctx, Seq, Err> {
448 type Item = Result<String, Err>;
449 fn next(&mut self) -> Option<Self::Item> {
450 if self.finished {
451 return None;
452 }
453 if let Some(error) = mem::take(&mut self.seq_error_in_pipe) {
455 self.finished = true;
456 return Some(Err(error));
457 }
458 let fields_ptr = self.progressor.as_ptr();
459 let (poll_outcome, item) = unsafe {
460 let future_ptr = &raw mut (*fields_ptr).future;
464 let pinned_future = Pin::new_unchecked((&mut *future_ptr).deref_mut());
465
466 let poll_outcome =
468 pinned_future.poll(&mut std::task::Context::from_waker(Waker::noop()));
469
470 let buffer_ptr = &raw const (*fields_ptr).buffer;
472 let item = (&*buffer_ptr).extract();
473
474 (poll_outcome, item)
475 };
476 match poll_outcome {
477 Poll::Pending => {
478 assert!(
481 item.is_some(),
482 "Extraneous async computations (writable should complete regularly)"
483 );
484 }
485 Poll::Ready(Err(seq_error)) => {
486 if item.is_some() {
489 self.seq_error_in_pipe = Some(seq_error);
492 } else {
493 self.finished = true;
495 return Some(Err(seq_error));
496 }
497 }
498 Poll::Ready(Ok(())) => {
499 self.finished = true;
501 }
503 };
504 item
505 }
506 }
507
508 struct SeqAccept<Ctx, Err> {
509 output: InMemoryOutput<Ctx, Err>,
510 buffer: *const ItemBuffer<Err>,
511 }
512
513 impl<Ctx, Err> SequenceAccept<InMemoryOutput<Ctx, Err>> for SeqAccept<Ctx, Err>
514 where
515 Ctx: Context,
516 Err: From<Infallible>,
517 {
518 async fn accept<W>(&mut self, writable: &W) -> Result<(), Err>
519 where
520 W: Writable<InMemoryOutput<Ctx, Err>>,
521 {
522 poll_fn(|_| {
523 let buffer = unsafe {
524 &*self.buffer
527 };
528 if !buffer.has_space() {
529 return Poll::Pending;
530 }
531 let result = self.output.print_output_impl(writable);
532 let string = mem::take(&mut self.output.buf);
533 buffer.set_new(result.map(|()| string));
534 Poll::Ready(Ok(()))
535 })
536 .await
537 }
538 }
539
540 struct ItemBuffer<Err>(Cell<Option<Result<String, Err>>>);
541
542 impl<Err> Default for ItemBuffer<Err> {
543 fn default() -> Self {
544 Self(Cell::new(None))
545 }
546 }
547
548 impl<Err> ItemBuffer<Err> {
549 fn inspect<F, R>(&self, op: F) -> R
550 where
551 F: FnOnce(&Option<Result<String, Err>>) -> R,
552 {
553 let current = self.0.take();
554 let result = op(¤t);
555 self.0.set(current);
556 result
557 }
558 }
559
560 impl<Err> Debug for ItemBuffer<Err>
561 where
562 Err: Debug,
563 {
564 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
565 self.inspect(|current| {
566 f.debug_struct("ItemBuffer")
567 .field("current", current)
568 .finish()
569 })
570 }
571 }
572
573 impl<Err> ItemBuffer<Err> {
574 fn has_space(&self) -> bool {
575 self.inspect(Option::is_none)
576 }
577
578 fn set_new(&self, value: Result<String, Err>) {
579 self.0.set(Some(value));
580 }
581
582 fn extract(&self) -> Option<Result<String, Err>> {
583 self.0.take()
584 }
585 }
586}
587
#[cfg(test)]
mod tests {
    use crate::common::{CombinedSeq, NoOpSeq, SingularSeq, Str, StrArrSeq};
    use crate::context::EmptyContext;
    use crate::util::IntoStringIter;
    use crate::{Output, SequenceAccept, Writable, WritableSeq};
    use std::convert::Infallible;

    #[test]
    fn sequence_iterator() {
        let sequence = StrArrSeq(&["One", "Two", "Three"]);
        let iterator = IntoStringIter::new(EmptyContext, sequence);
        let iterator = iterator.into_iter();
        let expected = &["One", "Two", "Three"].map(|s| Ok::<_, Infallible>(String::from(s)));
        assert_eq!(iterator.collect::<Vec<_>>(), Vec::from(expected));
    }

    #[test]
    fn sequence_iterator_empty() {
        let sequence = NoOpSeq;
        let iterator: IntoStringIter<_, _> = IntoStringIter::new(EmptyContext, sequence);
        let iterator = iterator.into_iter();
        assert!(iterator.collect::<Vec<_>>().is_empty());
    }

    /// Sequence that ends with `SampleError`, either instead of or after emitting `seq`.
    #[derive(Clone)]
    struct SequenceWithError<Seq> {
        /// `true`: fail immediately without emitting `seq`; `false`: emit `seq`, then fail.
        emit_before: bool,
        seq: Seq,
    }

    #[derive(Clone, Debug, PartialEq, Eq)]
    struct SampleError;

    impl From<Infallible> for SampleError {
        fn from(value: Infallible) -> Self {
            match value {}
        }
    }

    impl<O, Seq> WritableSeq<O> for SequenceWithError<Seq>
    where
        O: Output<Error = SampleError>,
        Seq: WritableSeq<O>,
    {
        async fn for_each<S>(&self, sink: &mut S) -> Result<(), O::Error>
        where
            S: SequenceAccept<O>,
        {
            if !self.emit_before {
                self.seq.for_each(sink).await?;
            }
            Err(SampleError)
        }
    }

    #[test]
    fn sequence_iterator_seq_error() {
        let sequence = SequenceWithError {
            emit_before: true,
            seq: StrArrSeq(&["Will", "Never", "Be", "Seen"]),
        };
        let iterator = IntoStringIter::<_, _, SampleError>::new(EmptyContext, sequence);
        assert_eq!(Some(Err(SampleError)), iterator.clone().into_iter().next());
        assert!(!iterator.into_iter().any(|item| item.is_ok()));
    }

    #[test]
    fn sequence_iterator_seq_error_afterward() {
        let sequence = SequenceWithError {
            emit_before: false,
            seq: StrArrSeq(&["Data", "More"]),
        };
        let iterator = IntoStringIter::<_, _, SampleError>::new(EmptyContext, sequence);
        assert_eq!(
            vec![
                Ok(String::from("Data")),
                Ok(String::from("More")),
                Err(SampleError)
            ],
            iterator.into_iter().collect::<Vec<_>>()
        );
    }

    #[test]
    fn sequence_iterator_seq_error_in_between() {
        let sequence = CombinedSeq(
            StrArrSeq(&["One", "Two"]),
            SequenceWithError {
                emit_before: true,
                seq: SingularSeq(Str("Final")),
            },
        );
        let iterator = IntoStringIter::<_, _, SampleError>::new(EmptyContext, sequence);
        assert_eq!(
            vec![
                Ok(String::from("One")),
                Ok(String::from("Two")),
                Err(SampleError)
            ],
            iterator.into_iter().collect::<Vec<_>>()
        );
    }

    #[test]
    fn sequence_iterator_seq_error_empty() {
        let sequence = SequenceWithError {
            emit_before: true,
            seq: NoOpSeq,
        };
        let iterator = IntoStringIter::<_, _, SampleError>::new(EmptyContext, sequence);
        assert_eq!(
            vec![Err(SampleError)],
            iterator.into_iter().collect::<Vec<_>>()
        );
    }

    #[derive(Clone, Debug)]
    struct ProduceError;

    impl<O> Writable<O> for ProduceError
    where
        O: Output<Error = SampleError>,
    {
        async fn write_to(&self, _: &mut O) -> Result<(), O::Error> {
            Err(SampleError)
        }
    }

    #[test]
    fn sequence_iterator_write_error() {
        let sequence = CombinedSeq(SingularSeq(ProduceError), StrArrSeq(&["Is", "Seen"]));
        let iterator = IntoStringIter::<_, _, SampleError>::new(EmptyContext, sequence);
        assert_eq!(Some(Err(SampleError)), iterator.clone().into_iter().next());
        assert_eq!(
            vec![
                Err(SampleError),
                Ok(String::from("Is")),
                Ok(String::from("Seen")),
            ],
            iterator.into_iter().collect::<Vec<_>>()
        );
    }

    #[test]
    fn sequence_iterator_write_error_afterward() {
        let sequence = CombinedSeq(StrArrSeq(&["Data", "MoreData"]), SingularSeq(ProduceError));
        let iterator = IntoStringIter::<_, _, SampleError>::new(EmptyContext, sequence);
        assert_eq!(
            vec![
                Ok(String::from("Data")),
                Ok(String::from("MoreData")),
                Err(SampleError)
            ],
            iterator.into_iter().collect::<Vec<_>>()
        );
    }

    #[test]
    fn sequence_iterator_write_error_in_between() {
        let sequence = CombinedSeq(
            StrArrSeq(&["Data", "Adjacent"]),
            CombinedSeq(SingularSeq(ProduceError), SingularSeq(Str("Final"))),
        );
        let iterator = IntoStringIter::<_, _, SampleError>::new(EmptyContext, sequence);
        assert_eq!(
            vec![
                Ok(String::from("Data")),
                Ok(String::from("Adjacent")),
                Err(SampleError),
                Ok(String::from("Final"))
            ],
            iterator.into_iter().collect::<Vec<_>>()
        );
    }

    /// Calling `next` after exhaustion must keep returning `None` without re-polling the
    /// already completed `for_each` future.
    #[test]
    fn sequence_iterator_stays_exhausted() {
        let sequence = StrArrSeq(&["Only"]);
        let mut iterator =
            IntoStringIter::<_, _, Infallible>::new(EmptyContext, sequence).into_iter();
        assert_eq!(iterator.next(), Some(Ok(String::from("Only"))));
        assert_eq!(iterator.next(), None);
        assert_eq!(iterator.next(), None);
    }

    #[test]
    fn sequence_iterator_stays_exhausted_after_seq_error() {
        let sequence = SequenceWithError {
            emit_before: true,
            seq: NoOpSeq,
        };
        let mut iterator =
            IntoStringIter::<_, _, SampleError>::new(EmptyContext, sequence).into_iter();
        assert_eq!(iterator.next(), Some(Err(SampleError)));
        assert_eq!(iterator.next(), None);
        assert_eq!(iterator.next(), None);
    }

    /// Exercises the unsafe teardown paths: dropping before the future was ever polled, and
    /// dropping while the future is suspended in the middle of the sequence.
    #[test]
    fn sequence_iterator_dropped_early() {
        let unpolled =
            IntoStringIter::<_, _, Infallible>::new(EmptyContext, StrArrSeq(&["Unused"]))
                .into_iter();
        drop(unpolled);

        let sequence = StrArrSeq(&["One", "Two", "Three"]);
        let mut iterator =
            IntoStringIter::<_, _, Infallible>::new(EmptyContext, sequence).into_iter();
        assert_eq!(iterator.next(), Some(Ok(String::from("One"))));
        drop(iterator);
    }
}