compio_io/framed/write.rs

1use std::{
2    io,
3    task::{Poll, ready},
4};
5
6use compio_buf::BufResult;
7use futures_util::{FutureExt, Sink};
8
9use crate::{
10    AsyncWrite, AsyncWriteExt, PinBoxFuture,
11    framed::{Framed, codec::Encoder, frame::Framer},
12};
13
14pub enum State<Io> {
15    Idle(Option<(Io, Vec<u8>)>),
16    Writing(PinBoxFuture<(Io, BufResult<(), Vec<u8>>)>),
17    Closing(PinBoxFuture<(Io, io::Result<()>, Vec<u8>)>),
18    Flushing(PinBoxFuture<(Io, io::Result<()>, Vec<u8>)>),
19}
20
21impl<Io> State<Io> {
22    fn take_idle(&mut self) -> (Io, Vec<u8>) {
23        match self {
24            State::Idle(idle) => idle.take().expect("Inconsistent state"),
25            _ => unreachable!("`Framed` not in idle state"),
26        }
27    }
28
29    pub fn buf(&mut self) -> Option<&mut Vec<u8>> {
30        match self {
31            State::Idle(Some((_, buf))) => Some(buf),
32            _ => None,
33        }
34    }
35
36    pub fn start_flush(&mut self)
37    where
38        Io: AsyncWrite + 'static,
39    {
40        let (mut io, buf) = self.take_idle();
41        let fut = Box::pin(async move {
42            let res = io.flush().await;
43            (io, res, buf)
44        });
45        *self = State::Flushing(fut);
46    }
47
48    pub fn start_close(&mut self)
49    where
50        Io: AsyncWrite + 'static,
51    {
52        let (mut io, buf) = self.take_idle();
53        let fut = Box::pin(async move {
54            let res = io.shutdown().await;
55            (io, res, buf)
56        });
57        *self = State::Closing(fut);
58    }
59
60    pub fn start_write(&mut self)
61    where
62        Io: AsyncWrite + 'static,
63    {
64        let (mut io, buf) = self.take_idle();
65        let fut = Box::pin(async move {
66            let res = io.write_all(buf).await;
67            (io, res)
68        });
69        *self = State::Writing(fut);
70    }
71
72    /// State that may occur when `Framed` is acting as a [`Sink`].
73    pub fn poll_sink(&mut self, cx: &mut std::task::Context<'_>) -> Poll<io::Result<()>> {
74        let (io, res, buf) = match self {
75            State::Writing(fut) => {
76                let (io, BufResult(res, buf)) = ready!(fut.poll_unpin(cx));
77                (io, res, buf)
78            }
79            State::Closing(fut) | State::Flushing(fut) => ready!(fut.poll_unpin(cx)),
80            State::Idle(_) => {
81                return Poll::Ready(Ok(()));
82            }
83        };
84        *self = State::Idle(Some((io, buf)));
85        Poll::Ready(res)
86    }
87}
88
89impl<R, W, C, F, In, Out> Sink<In> for Framed<R, W, C, F, In, Out>
90where
91    W: AsyncWrite + 'static,
92    C: Encoder<In>,
93    F: Framer,
94    Self: Unpin,
95{
96    type Error = C::Error;
97
98    fn poll_ready(
99        self: std::pin::Pin<&mut Self>,
100        cx: &mut std::task::Context<'_>,
101    ) -> std::task::Poll<Result<(), Self::Error>> {
102        let this = self.get_mut();
103        match &mut this.write_state {
104            State::Idle(..) => Poll::Ready(Ok(())),
105            state => state.poll_sink(cx).map_err(C::Error::from),
106        }
107    }
108
109    fn start_send(self: std::pin::Pin<&mut Self>, item: In) -> Result<(), Self::Error> {
110        let this = self.get_mut();
111
112        let buf = this.write_state.buf().expect("`Framed` not in idle state");
113        buf.clear();
114        buf.reserve(64);
115        this.codec.encode(item, buf)?;
116        this.framer.enclose(buf);
117        this.write_state.start_write();
118
119        Ok(())
120    }
121
122    fn poll_flush(
123        self: std::pin::Pin<&mut Self>,
124        cx: &mut std::task::Context<'_>,
125    ) -> std::task::Poll<Result<(), Self::Error>> {
126        let this = self.get_mut();
127        match &mut this.write_state {
128            State::Idle(_) => {
129                this.write_state.start_flush();
130                this.write_state.poll_sink(cx).map_err(C::Error::from)
131            }
132            State::Writing(_) | State::Flushing(_) => {
133                this.write_state.poll_sink(cx).map_err(C::Error::from)
134            }
135            State::Closing(_) => unreachable!("`Framed` is closing, cannot flush"),
136        }
137    }
138
139    fn poll_close(
140        self: std::pin::Pin<&mut Self>,
141        cx: &mut std::task::Context<'_>,
142    ) -> std::task::Poll<Result<(), Self::Error>> {
143        let this = self.get_mut();
144        match &mut this.write_state {
145            state @ State::Idle(_) => {
146                state.start_close();
147                state.poll_sink(cx).map_err(C::Error::from)
148            }
149            _ => this.write_state.poll_sink(cx).map_err(C::Error::from),
150        }
151    }
152}