1use core::fmt;
2use std::{
3 cell::RefCell,
4 future::{poll_fn, Future},
5 panic::{catch_unwind, AssertUnwindSafe},
6 pin::Pin,
7 sync::Arc,
8 task::{ready, Context, Poll, Waker},
9};
10
11use futures_core::Stream;
12use pin_project_lite::pin_project;
13use serde::Serialize;
14
15use crate::{DynOutput, ProcedureError};
16
thread_local! {
    // Whether `flush()` is allowed to yield on this thread. Nothing in this file sets it, so
    // by default `flush()` completes immediately.
    static CAN_FLUSH: RefCell<bool> = RefCell::new(false);
    // Set to `Some(true)` by `flush()` when it yields, recording that a flush was requested.
    static SHOULD_FLUSH: RefCell<Option<bool>> = RefCell::new(None);
}

/// Requests a flush of buffered output by yielding exactly once.
///
/// When the thread-local `CAN_FLUSH` flag is `false`, this resolves immediately without
/// touching any state. Otherwise, the first poll stores `Some(true)` in `SHOULD_FLUSH` and
/// returns `Poll::Pending`, and the next poll completes.
///
/// NOTE(review): the yielding poll ignores its `Context` and registers no waker. Whatever drives
/// this future must re-poll it on its own.
pub async fn flush() {
    let enabled = CAN_FLUSH.with(|flag| *flag.borrow());
    if !enabled {
        return;
    }

    let mut yielded = false;
    poll_fn(|_cx| {
        if yielded {
            return Poll::Ready(());
        }

        yielded = true;
        SHOULD_FLUSH.replace(Some(true));
        Poll::Pending
    })
    .await
}
38
39enum Inner {
40 Dyn(Pin<Box<dyn DynReturnValue>>),
41 Value(Option<ProcedureError>),
42}
43
44#[must_use = "`ProcedureStream` does nothing unless polled"]
46pub struct ProcedureStream {
47 inner: Inner,
48 flush: Option<Waker>,
55 pending_value: bool, }
59
60impl From<ProcedureError> for ProcedureStream {
61 fn from(err: ProcedureError) -> Self {
62 Self {
63 inner: Inner::Value(Some(err)),
64 flush: None,
65 pending_value: false,
66 }
67 }
68}
69
70impl ProcedureStream {
71 pub fn from_stream<T, S>(s: S) -> Self
73 where
74 S: Stream<Item = Result<T, ProcedureError>> + Send + 'static,
75 T: Serialize + Send + Sync + 'static,
76 {
77 Self {
78 inner: Inner::Dyn(Box::pin(GenericDynReturnValue {
79 inner: s,
80 poll: |s, cx| s.poll_next(cx),
81 size_hint: |s| s.size_hint(),
82 resolved: |_| true,
83 as_value: |v| {
84 DynOutput::new_serialize(
85 v.as_mut()
86 .expect("unreachable")
88 .as_mut()
89 .expect("unreachable"),
91 )
92 },
93 flushed: false,
94 unwound: false,
95 value: None,
96 })),
97 flush: None,
98 pending_value: false,
99 }
100 }
101
102 pub fn from_future<T, F>(f: F) -> Self
104 where
105 F: Future<Output = Result<T, ProcedureError>> + Send + 'static,
106 T: Serialize + Send + Sync + 'static,
107 {
108 pin_project! {
109 #[project = ReprProj]
110 struct Repr<F> {
111 #[pin]
112 inner: Option<F>,
113 }
114 }
115
116 Self {
117 inner: Inner::Dyn(Box::pin(GenericDynReturnValue {
118 inner: Repr { inner: Some(f) },
119 poll: |f, cx| {
120 let mut this = f.project();
121 let v = match this.inner.as_mut().as_pin_mut() {
122 Some(fut) => ready!(fut.poll(cx)),
123 None => return Poll::Ready(None),
124 };
125
126 this.inner.set(None);
127 Poll::Ready(Some(v))
128 },
129 size_hint: |f| {
130 if f.inner.is_some() {
131 (1, Some(1))
132 } else {
133 (0, Some(0))
134 }
135 },
136 as_value: |v| {
137 DynOutput::new_serialize(
138 v.as_mut()
139 .expect("unreachable")
141 .as_mut()
142 .expect("unreachable"),
144 )
145 },
146 resolved: |f| f.inner.is_none(),
147 flushed: false,
148 unwound: false,
149 value: None,
150 })),
151 flush: None,
152 pending_value: false,
153 }
154 }
155
156 pub fn from_future_stream<T, F, S>(f: F) -> Self
158 where
159 F: Future<Output = Result<S, ProcedureError>> + Send + 'static,
160 S: Stream<Item = Result<T, ProcedureError>> + Send + 'static,
161 T: Serialize + Send + Sync + 'static,
162 {
163 pin_project! {
164 #[project = ReprProj]
165 enum Repr<F, S> {
166 Future {
167 #[pin]
168 inner: F,
169 },
170 Stream {
171 #[pin]
172 inner: S,
173 },
174 }
175 }
176
177 Self {
178 inner: Inner::Dyn(Box::pin(GenericDynReturnValue {
179 inner: Repr::<F, S>::Future { inner: f },
180 poll: |mut f, cx| loop {
181 let this = f.as_mut().project();
182 match this {
183 ReprProj::Future { inner } => {
184 let Poll::Ready(Ok(stream)) = inner.poll(cx) else {
185 return Poll::Pending;
186 };
187
188 f.set(Repr::Stream { inner: stream });
189 continue;
190 }
191 ReprProj::Stream { inner } => return inner.poll_next(cx),
192 }
193 },
194 size_hint: |_| (1, Some(1)),
195 resolved: |f| matches!(f, Repr::Stream { .. }),
196 as_value: |v| {
197 DynOutput::new_serialize(
198 v.as_mut()
199 .expect("unreachable")
201 .as_mut()
202 .expect("unreachable"),
204 )
205 },
206 flushed: false,
207 unwound: false,
208 value: None,
209 })),
210 flush: None,
211 pending_value: false,
212 }
213 }
214
215 pub fn from_stream_value<T, S>(s: S) -> Self
217 where
218 S: Stream<Item = Result<T, ProcedureError>> + Send + 'static,
219 T: Send + Sync + 'static,
220 {
221 Self {
222 inner: Inner::Dyn(Box::pin(GenericDynReturnValue {
223 inner: s,
224 poll: |s, cx| s.poll_next(cx),
225 size_hint: |s| s.size_hint(),
226 resolved: |_| true,
227 as_value: |v| DynOutput::new_value(v),
229 flushed: false,
230 unwound: false,
231 value: None,
232 })),
233 flush: None,
234 pending_value: false,
235 }
236 }
237
238 pub fn from_future_value<T, F>(f: F) -> Self
240 where
241 F: Future<Output = Result<T, ProcedureError>> + Send + 'static,
242 T: Send + Sync + 'static,
243 {
244 pin_project! {
245 #[project = ReprProj]
246 struct Repr<F> {
247 #[pin]
248 inner: Option<F>,
249 }
250 }
251
252 Self {
253 inner: Inner::Dyn(Box::pin(GenericDynReturnValue {
254 inner: Repr { inner: Some(f) },
255 poll: |f, cx| {
256 let mut this = f.project();
257 let v = match this.inner.as_mut().as_pin_mut() {
258 Some(fut) => ready!(fut.poll(cx)),
259 None => return Poll::Ready(None),
260 };
261
262 this.inner.set(None);
263 Poll::Ready(Some(v))
264 },
265 size_hint: |f| {
266 if f.inner.is_some() {
267 (1, Some(1))
268 } else {
269 (0, Some(0))
270 }
271 },
272 as_value: |v| DynOutput::new_value(v),
273 resolved: |f| f.inner.is_none(),
274 flushed: false,
275 unwound: false,
276 value: None,
277 })),
278 flush: None,
279 pending_value: false,
280 }
281 }
282
283 pub fn from_future_stream_value<T, F, S>(f: F) -> Self
285 where
286 F: Future<Output = Result<S, ProcedureError>> + Send + 'static,
287 S: Stream<Item = Result<T, ProcedureError>> + Send + 'static,
288 T: Send + Sync + 'static,
289 {
290 pin_project! {
291 #[project = ReprProj]
292 enum Repr<F, S> {
293 Future {
294 #[pin]
295 inner: F,
296 },
297 Stream {
298 #[pin]
299 inner: S,
300 },
301 }
302 }
303
304 Self {
305 inner: Inner::Dyn(Box::pin(GenericDynReturnValue {
306 inner: Repr::<F, S>::Future { inner: f },
307 poll: |mut f, cx| loop {
308 let this = f.as_mut().project();
309 match this {
310 ReprProj::Future { inner } => {
311 let Poll::Ready(Ok(stream)) = inner.poll(cx) else {
312 return Poll::Pending;
313 };
314
315 f.set(Repr::Stream { inner: stream });
316 continue;
317 }
318 ReprProj::Stream { inner } => return inner.poll_next(cx),
319 }
320 },
321 size_hint: |_| (1, Some(1)),
322 resolved: |f| matches!(f, Repr::Stream { .. }),
323 as_value: |v| DynOutput::new_value(v),
324 flushed: false,
325 unwound: false,
326 value: None,
327 })),
328 flush: None,
329 pending_value: false,
330 }
331 }
332
333 pub fn require_manual_stream(mut self) -> Self {
351 struct NoOpWaker;
353 impl std::task::Wake for NoOpWaker {
354 fn wake(self: std::sync::Arc<Self>) {}
355 }
356
357 self.flush = Some(Arc::new(NoOpWaker).into());
359 self
360 }
361
362 pub fn stream(&mut self) {
365 if let Some(waker) = self.flush.take() {
366 waker.wake();
367 }
368 }
369
370 pub fn resolved(&self) -> bool {
374 match &self.inner {
375 Inner::Dyn(stream) => stream.resolved(),
376 Inner::Value(_) => true,
377 }
378 }
379
380 pub fn flushable(&self) -> bool {
384 match &self.inner {
385 Inner::Dyn(stream) => stream.flushed(),
386 Inner::Value(_) => false,
387 }
388 }
389
390 pub fn size_hint(&self) -> (usize, Option<usize>) {
392 match &self.inner {
393 Inner::Dyn(stream) => stream.size_hint(),
394 Inner::Value(_) => (1, Some(1)),
395 }
396 }
397
398 fn poll_inner(&mut self, cx: &mut Context<'_>) -> Poll<Option<()>> {
399 if let Some(waker) = &mut self.flush {
401 if !waker.will_wake(cx.waker()) {
402 self.flush = Some(cx.waker().clone());
403 }
404 }
405
406 if self.pending_value {
407 return if self.flush.is_none() {
408 self.pending_value = false;
410 Poll::Ready(Some(()))
411 } else {
412 Poll::Pending
414 };
415 }
416
417 match &mut self.inner {
418 Inner::Dyn(v) => match v.as_mut().poll_next_value(cx) {
419 Poll::Ready(v) => {
420 if self.flush.is_none() {
421 Poll::Ready(v)
422 } else {
423 match v {
424 Some(v) => {
425 self.pending_value = true;
426 Poll::Pending
427 }
428 None => Poll::Ready(None),
429 }
430 }
431 }
432 Poll::Pending => Poll::Pending,
433 },
434 Inner::Value(v) => {
435 if self.flush.is_none() {
436 todo!();
438 } else {
439 Poll::Pending
440 }
441 }
442 }
443 }
444
445 pub fn poll_next(
447 &mut self,
448 cx: &mut Context<'_>,
449 ) -> Poll<Option<Result<DynOutput<'_>, ProcedureError>>> {
450 self.poll_inner(cx).map(|v| {
451 v.map(|_: ()| {
452 let Inner::Dyn(s) = &mut self.inner else {
453 unreachable!(); };
455 s.as_mut().value()
456 })
457 })
458 }
459
460 pub async fn next(&mut self) -> Option<Result<DynOutput<'_>, ProcedureError>> {
462 poll_fn(|cx| self.poll_inner(cx)).await.map(|_: ()| {
463 let Inner::Dyn(s) = &mut self.inner else {
464 unreachable!(); };
466 s.as_mut().value()
467 })
468 }
469
470 pub fn map<F: FnMut(Result<DynOutput, ProcedureError>) -> Result<T, String>, T>(
473 self,
474 map: F,
475 ) -> ProcedureStreamMap<F, T> {
476 ProcedureStreamMap { stream: self, map }
477 }
478}
479
480pub struct ProcedureStreamMap<F: FnMut(Result<DynOutput, ProcedureError>) -> Result<T, String>, T> {
481 stream: ProcedureStream,
482 map: F,
483}
484
485impl<F: FnMut(Result<DynOutput, ProcedureError>) -> Result<T, String>, T> ProcedureStreamMap<F, T> {
486 pub fn stream(&mut self) {
489 self.stream.stream();
490 }
491
492 pub fn resolved(&self) -> bool {
496 self.stream.resolved()
497 }
498
499 pub fn flushable(&self) -> bool {
503 self.stream.flushable()
504 }
505}
506
507impl<F: FnMut(Result<DynOutput, ProcedureError>) -> Result<T, String> + Unpin, T> Stream
509 for ProcedureStreamMap<F, T>
510{
511 type Item = T;
512
513 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
514 let this = self.get_mut();
515
516 this.stream.poll_inner(cx).map(|v| {
517 v.map(|_: ()| {
518 let Inner::Dyn(s) = &mut this.stream.inner else {
519 unreachable!();
520 };
521
522 match (this.map)(s.as_mut().value()) {
523 Ok(v) => v,
524 Err(err) => {
527 println!("Error serialzing {err:?}");
528 todo!();
529 }
530 }
531 })
532 })
533 }
534
535 fn size_hint(&self) -> (usize, Option<usize>) {
536 self.stream.size_hint()
537 }
538}
539
540impl fmt::Debug for ProcedureStream {
541 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
542 todo!();
543 }
544}
545
546trait DynReturnValue: Send {
547 fn poll_next_value<'a>(self: Pin<&'a mut Self>, cx: &mut Context<'_>) -> Poll<Option<()>>;
548 fn value(self: Pin<&mut Self>) -> Result<DynOutput<'_>, ProcedureError>;
549 fn size_hint(&self) -> (usize, Option<usize>);
550 fn resolved(&self) -> bool;
551 fn flushed(&self) -> bool;
552}
553
554pin_project! {
555 struct GenericDynReturnValue<S, T> {
556 #[pin]
557 inner: S,
558 poll: fn(Pin<&mut S>, &mut Context) -> Poll<Option<Result<T, ProcedureError>>>,
560 size_hint: fn(&S) -> (usize, Option<usize>),
562 as_value: fn(&mut Option<Result<T, ProcedureError>>) -> DynOutput<'_>,
564 resolved: fn(&S) -> bool,
566 flushed: bool,
568 unwound: bool,
570 value: Option<Result<T, ProcedureError>>,
574 }
575}
576
577impl<S: Send, T: Send> DynReturnValue for GenericDynReturnValue<S, T> {
578 fn poll_next_value<'a>(mut self: Pin<&'a mut Self>, cx: &mut Context<'_>) -> Poll<Option<()>> {
579 if self.unwound {
580 return Poll::Ready(None);
582 }
583
584 let this = self.as_mut().project();
585 let r = catch_unwind(AssertUnwindSafe(|| {
586 let _ = this.value.take(); (this.poll)(this.inner, cx).map(|v| {
588 v.map(|v| {
589 *this.value = Some(v);
590 ()
591 })
592 })
593 }));
594
595 match r {
596 Ok(v) => v,
597 Err(err) => {
598 *this.unwound = true;
599 *this.value = Some(Err(ProcedureError::Unwind(err)));
600 Poll::Ready(Some(()))
601 }
602 }
603 }
604
605 fn value(self: Pin<&mut Self>) -> Result<DynOutput<'_>, ProcedureError> {
606 let this = self.project();
607 match this.value {
608 Some(Err(_)) => {
609 let Some(Err(err)) = std::mem::replace(this.value, None) else {
610 unreachable!(); };
612 Err(err)
613 }
614 v => Ok((this.as_value)(v)),
615 }
616 }
617
618 fn size_hint(&self) -> (usize, Option<usize>) {
619 (self.size_hint)(&self.inner)
620 }
621
622 fn resolved(&self) -> bool {
623 (self.resolved)(&self.inner)
624 }
625 fn flushed(&self) -> bool {
626 self.flushed
627 }
628}