async_codegen/
util.rs

1/*
2 * Copyright © 2025 Anand Beh
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16use crate::common::SequenceConfig;
17use crate::context::Context;
18use crate::{IoOutput, Output, Writable, WritableSeq};
19use std::convert::Infallible;
20use std::marker::PhantomData;
21use std::pin::pin;
22use std::task::{Poll, Waker};
23
24/// A wrapper intended to be used as a [SequenceConfig].
25/// The tuple value should be a function that returns a writable for the data type.
26pub struct WritableFromFunction<F>(pub F);
27
28impl<F, T, W, O> SequenceConfig<T, O> for WritableFromFunction<F>
29where
30    O: Output,
31    F: Fn(&T) -> W,
32    W: Writable<O>,
33{
34    async fn write_datum(&self, datum: &T, output: &mut O) -> Result<(), O::Error> {
35        let writable = (&self.0)(datum);
36        writable.write_to(output).await
37    }
38}
39
40/// An IO implementation that just writes to a string. The futures produced from methods
41/// like [IoOutput::write] are complete instantly, and no errors are produced.
42pub struct InMemoryIo<'s>(pub &'s mut String);
43
44impl IoOutput for InMemoryIo<'_> {
45    type Error = Infallible;
46
47    async fn write(&mut self, value: &str) -> Result<(), Self::Error> {
48        self.0.push_str(value);
49        Ok(())
50    }
51}
52
/// An [Output] that accumulates everything written to it in an owned [String].
///
/// Writes never fail and never suspend. This is the owning counterpart of [InMemoryIo],
/// which only borrows the string it appends to.
pub struct InMemoryOutput<Ctx> {
    buf: String,
    context: Ctx,
}

impl<Ctx> InMemoryOutput<Ctx> {
    /// Creates an output with an empty buffer that carries `context`.
    pub fn new(context: Ctx) -> Self {
        let buf = String::new();
        Self { context, buf }
    }
}
68
69impl<Ctx> Output for InMemoryOutput<Ctx>
70where
71    Ctx: Context,
72{
73    type Io<'b>
74        = InMemoryIo<'b>
75    where
76        Self: 'b;
77    type Ctx = Ctx;
78    type Error = Infallible;
79
80    async fn write(&mut self, value: &str) -> Result<(), Self::Error> {
81        self.buf.push_str(value);
82        Ok(())
83    }
84
85    fn split(&mut self) -> (Self::Io<'_>, &Self::Ctx) {
86        (InMemoryIo(&mut self.buf), &self.context)
87    }
88
89    fn context(&self) -> &Self::Ctx {
90        &self.context
91    }
92}
93
94impl<Ctx> InMemoryOutput<Ctx>
95where
96    Ctx: Context,
97{
98    /// Gets the string output of a single writable.
99    ///
100    /// Assumes that the only source of async is this [InMemoryOutput] itself, i.e. the writable
101    /// will never return a pending future unless the output used with it returns a pending future.
102    ///
103    /// **Panics** if the writable returns a pending future (particularly, this may happen if
104    /// a writable introduces its own async computations that do not come from the output)
105    pub fn print_output<W>(context: Ctx, writable: &W) -> String
106    where
107        W: Writable<Self>,
108    {
109        Self::print_output_impl(context, writable).buf
110    }
111
112    fn print_output_impl<W>(context: Ctx, writable: &W) -> Self
113    where
114        W: Writable<Self>,
115    {
116        let mut output = Self {
117            buf: String::new(),
118            context,
119        };
120        let result = {
121            let future = pin!(writable.write_to(&mut output));
122            match future.poll(&mut std::task::Context::from_waker(Waker::noop())) {
123                Poll::Pending => panic!("Expected a complete future"),
124                Poll::Ready(result) => result,
125            }
126        };
127        match result {
128            Ok(()) => output,
129            Err(e) => match e {}, // Unreachable
130        }
131    }
132}
133
134/// Turns a [WritableSeq] into an iterator over owned strings.
135/// This calls [InMemoryOutput::print_output] for each writable produced by the sequence. It uses
136/// Rust's async design in order to produce lazy iteration.
137///
138/// Iteration will **panic** if any writable returns a pending future (in particular, this may
139/// happen if a writable introduces its own async computations that do not come from the output)
140pub struct IntoStringIter<'f, Ctx, Seq>(
141    string_iter::IntoStringIterStart<Ctx, Seq>,
142    PhantomData<&'f ()>,
143);
144
145impl<'f, Ctx, Seq> IntoStringIter<'f, Ctx, Seq> {
146    pub fn new(context: Ctx, sequence: Seq) -> Self {
147        Self(
148            string_iter::IntoStringIterStart { context, sequence },
149            PhantomData,
150        )
151    }
152}
153
154impl<'f, Ctx, Seq> IntoIterator for IntoStringIter<'f, Ctx, Seq>
155where
156    Ctx: Context + 'f,
157    Seq: WritableSeq<InMemoryOutput<Ctx>> + 'f,
158{
159    type Item = String;
160    type IntoIter = ToStringIter<'f, Ctx, Seq>;
161    fn into_iter(self) -> Self::IntoIter {
162        ToStringIter {
163            marker: PhantomData,
164            progressor: string_iter::Progressor::from(self.0),
165        }
166    }
167}
168
169pub struct ToStringIter<'f, Ctx, Seq> {
170    marker: PhantomData<(Ctx, Seq)>,
171    progressor: string_iter::Progressor<'f>,
172}
173
174impl<'f, Ctx, Seq> Iterator for ToStringIter<'f, Ctx, Seq>
175where
176    Ctx: Context,
177    Seq: WritableSeq<InMemoryOutput<Ctx>>,
178{
179    type Item = String;
180    fn next(&mut self) -> Option<Self::Item> {
181        self.progressor.next()
182    }
183}
184
185mod string_iter {
186    use crate::util::InMemoryOutput;
187    use crate::{Output, SequenceAccept, Writable, WritableSeq};
188    use std::cell::UnsafeCell;
189    use std::convert::Infallible;
190    use std::future::poll_fn;
191    use std::marker::PhantomData;
192    use std::pin::Pin;
193    use std::ptr::NonNull;
194    use std::task::{Context, Poll, Waker};
195    use std::{mem, ptr};
196
197    struct AdaptSeq<O, Seq>(PhantomData<O>, Seq);
198
199    impl<O, Seq> AdaptSeq<O, Seq>
200    where
201        O: Output<Error = Infallible>,
202        Seq: WritableSeq<O>,
203    {
204        fn make_future<S>(self, sink: S) -> impl Future<Output = ()> + use<O, Seq, S>
205        where
206            S: SequenceAccept<O>,
207        {
208            async move {
209                let seq = self.1;
210                let mut sink = sink;
211                let res = WritableSeq::for_each(&seq, &mut sink).await;
212                match res {
213                    Ok(()) => (),
214                    Err(e) => match e {},
215                }
216            }
217        }
218    }
219
220    pub struct IntoStringIterStart<Ctx, Seq> {
221        pub context: Ctx,
222        pub sequence: Seq,
223    }
224
225    pub struct Progressor<'f> {
226        // We use a manual pointer to share data in multiple places
227        buffer: NonNull<ItemBuffer>,
228        future: Pin<Box<dyn Future<Output = ()> + 'f>>,
229        finished: bool,
230    }
231
232    impl<'f> Drop for Progressor<'f> {
233        fn drop(&mut self) {
234            unsafe {
235                // SAFETY
236                // We are not dropping the shared data anywhere else
237                ptr::drop_in_place(self.buffer.as_ptr());
238            }
239        }
240    }
241
242    impl<'f, Ctx, Seq> From<IntoStringIterStart<Ctx, Seq>> for Progressor<'f>
243    where
244        Ctx: super::Context + 'f,
245        Seq: WritableSeq<InMemoryOutput<Ctx>> + 'f,
246    {
247        fn from(value: IntoStringIterStart<Ctx, Seq>) -> Self {
248            let buffer = Box::new(ItemBuffer::default());
249            let buffer = unsafe {
250                // SAFETY -- one of Rust's API mistakes was not returning NonNull
251                NonNull::new_unchecked(Box::into_raw(buffer))
252            };
253            let adapt_seq = AdaptSeq(PhantomData, value.sequence);
254            let seq_accept = SeqAccept {
255                context: Some(value.context),
256                buffer,
257            };
258            let future = adapt_seq.make_future(seq_accept);
259            Self {
260                buffer,
261                future: Box::pin(future),
262                finished: false,
263            }
264        }
265    }
266
267    impl<'f> Iterator for Progressor<'f> {
268        type Item = String;
269        fn next(&mut self) -> Option<Self::Item> {
270            if self.finished {
271                return None;
272            }
273            let poll = self
274                .future
275                .as_mut()
276                .poll(&mut Context::from_waker(Waker::noop()));
277
278            let buffer = unsafe {
279                // SAFETY
280                // We touch this buffer in one other place, inside the future
281                &*self.buffer.as_ptr()
282            };
283            let extract = buffer.extract();
284
285            match poll {
286                Poll::Pending => {
287                    // The future stalling tells us that there is an item in the buffer
288                    // So, this should always be Some
289                    assert!(
290                        extract.is_some(),
291                        "Extraneous async computations (writable should complete regularly)"
292                    );
293                    extract
294                }
295                Poll::Ready(()) => {
296                    // All done!
297                    self.finished = true;
298                    // Can be None if we are fully empty
299                    extract
300                }
301            }
302        }
303    }
304
305    struct ItemBuffer {
306        current: UnsafeCell<Option<String>>,
307    }
308
309    impl Default for ItemBuffer {
310        fn default() -> Self {
311            Self {
312                current: UnsafeCell::new(None),
313            }
314        }
315    }
316
317    impl ItemBuffer {
318        fn has_space(&self) -> bool {
319            unsafe {
320                // SAFETY
321                // We only read and don't modify from anywhere else
322                let ptr = self.current.get();
323                (&*ptr).is_none()
324            }
325        }
326
327        fn set_new(&self, value: String) {
328            unsafe {
329                // SAFETY
330                // We don't modify from anywhere else
331                let ptr = self.current.get();
332                *ptr = Some(value);
333            }
334        }
335
336        fn extract(&self) -> Option<String> {
337            unsafe {
338                // SAFETY
339                // We don't modify from anywhere else
340                let ptr = self.current.get();
341                mem::replace(&mut *ptr, None)
342            }
343        }
344    }
345
346    struct SeqAccept<Ctx> {
347        // This is always Some() except during a panic
348        context: Option<Ctx>,
349        buffer: NonNull<ItemBuffer>,
350    }
351
352    impl<Ctx> SequenceAccept<InMemoryOutput<Ctx>> for SeqAccept<Ctx>
353    where
354        Ctx: super::Context,
355    {
356        async fn accept<W>(&mut self, writable: &W) -> Result<(), Infallible>
357        where
358            W: Writable<InMemoryOutput<Ctx>>,
359        {
360            poll_fn(|_| {
361                let buffer = unsafe {
362                    // SAFETY
363                    // We touch this buffer in one other place, outside the future
364                    &*self.buffer.as_ptr()
365                };
366                if !buffer.has_space() {
367                    return Poll::Pending;
368                }
369                unsafe {
370                    // SAFETY
371                    // We take the Ctx, then put it back, without modifying it
372                    let ctx_store = &mut self.context;
373                    let context = mem::take(ctx_store).unwrap_unchecked();
374                    // Panic safety: upon panic, Ctx is dropped and the None prevents double-free
375                    let InMemoryOutput {
376                        buf: string,
377                        context,
378                    } = InMemoryOutput::print_output_impl(context, writable);
379                    *ctx_store = Some(context);
380
381                    // Then, we put the string into the storage
382                    buffer.set_new(string);
383                }
384                Poll::Ready(Ok(()))
385            })
386            .await
387        }
388    }
389}
390
#[cfg(test)]
mod tests {
    use crate::common::{NoOpSeq, StrArrSeq};
    use crate::context::EmptyContext;
    use crate::util::IntoStringIter;

    #[test]
    fn sequence_iterator() {
        let sequence = StrArrSeq(&["One", "Two", "Three"]);
        let iterator = IntoStringIter::new(EmptyContext, sequence);
        let iterator = iterator.into_iter();
        let expected = &["One", "Two", "Three"].map(String::from);
        assert_eq!(iterator.collect::<Vec<_>>(), Vec::from(expected));
    }

    #[test]
    fn sequence_iterator_empty() {
        let sequence = NoOpSeq;
        let iterator = IntoStringIter::new(EmptyContext, sequence);
        let iterator = iterator.into_iter();
        assert!(iterator.collect::<Vec<_>>().is_empty());
    }

    /// After the last item has been yielded, further calls keep returning `None`.
    #[test]
    fn sequence_iterator_stays_exhausted() {
        let sequence = StrArrSeq(&["Only"]);
        let mut iterator = IntoStringIter::new(EmptyContext, sequence).into_iter();
        assert_eq!(iterator.next().as_deref(), Some("Only"));
        assert_eq!(iterator.next(), None);
        assert_eq!(iterator.next(), None);
    }

    /// Dropping a partially consumed iterator must tear down the suspended sequence
    /// cleanly. Run under Miri to also catch leaks or invalid frees on this path.
    #[test]
    fn sequence_iterator_dropped_early() {
        let sequence = StrArrSeq(&["One", "Two", "Three"]);
        let mut iterator = IntoStringIter::new(EmptyContext, sequence).into_iter();
        assert_eq!(iterator.next().as_deref(), Some("One"));
        drop(iterator);
    }
}