compio_io/framed/
write.rs

1use std::{
2    io,
3    task::{Poll, ready},
4};
5
6use compio_buf::BufResult;
7use futures_util::{FutureExt, Sink};
8
9use crate::{
10    AsyncWrite, AsyncWriteExt, PinBoxFuture,
11    framed::{Framed, codec::Encoder, frame::Framer},
12};
13
14pub enum State<Io> {
15    Idle(Option<(Io, Vec<u8>)>),
16    Writing(PinBoxFuture<(Io, BufResult<(), Vec<u8>>)>),
17    Closing(PinBoxFuture<(Io, io::Result<()>, Vec<u8>)>),
18    Flushing(PinBoxFuture<(Io, io::Result<()>, Vec<u8>)>),
19}
20
21impl<Io> State<Io> {
22    pub fn new(io: Io, buf: Vec<u8>) -> Self {
23        State::Idle(Some((io, buf)))
24    }
25
26    pub fn empty() -> Self {
27        State::Idle(None)
28    }
29
30    fn take_idle(&mut self) -> (Io, Vec<u8>) {
31        match self {
32            State::Idle(idle) => idle.take().expect("Inconsistent state"),
33            _ => unreachable!("`Framed` not in idle state"),
34        }
35    }
36
37    fn buf(&mut self) -> Option<&mut Vec<u8>> {
38        match self {
39            State::Idle(Some((_, buf))) => Some(buf),
40            _ => None,
41        }
42    }
43
44    fn poll_sink(&mut self, cx: &mut std::task::Context<'_>) -> Poll<io::Result<()>> {
45        let (io, res, buf) = match self {
46            State::Writing(fut) => {
47                let (io, BufResult(res, buf)) = ready!(fut.poll_unpin(cx));
48                (io, res, buf)
49            }
50            State::Closing(fut) | State::Flushing(fut) => ready!(fut.poll_unpin(cx)),
51            State::Idle(_) => {
52                return Poll::Ready(Ok(()));
53            }
54        };
55        *self = State::Idle(Some((io, buf)));
56        Poll::Ready(res)
57    }
58}
59
60impl<Io: AsyncWrite + 'static> State<Io> {
61    fn start_flush(&mut self) {
62        let (mut io, buf) = self.take_idle();
63        let fut = Box::pin(async move {
64            let res = io.flush().await;
65            (io, res, buf)
66        });
67        *self = State::Flushing(fut);
68    }
69
70    fn start_close(&mut self) {
71        let (mut io, buf) = self.take_idle();
72        let fut = Box::pin(async move {
73            let res = io.shutdown().await;
74            (io, res, buf)
75        });
76        *self = State::Closing(fut);
77    }
78
79    fn start_write(&mut self) {
80        let (mut io, buf) = self.take_idle();
81        let fut = Box::pin(async move {
82            let res = io.write_all(buf).await;
83            (io, res)
84        });
85        *self = State::Writing(fut);
86    }
87}
88
89impl<R, W, C, F, In, Out> Sink<In> for Framed<R, W, C, F, In, Out>
90where
91    W: AsyncWrite + 'static,
92    C: Encoder<In>,
93    F: Framer,
94    Self: Unpin,
95{
96    type Error = C::Error;
97
98    fn poll_ready(
99        self: std::pin::Pin<&mut Self>,
100        cx: &mut std::task::Context<'_>,
101    ) -> std::task::Poll<Result<(), Self::Error>> {
102        let this = self.get_mut();
103        match &mut this.write_state {
104            State::Idle(..) => Poll::Ready(Ok(())),
105            state => state.poll_sink(cx).map_err(C::Error::from),
106        }
107    }
108
109    fn start_send(self: std::pin::Pin<&mut Self>, item: In) -> Result<(), Self::Error> {
110        let this = self.get_mut();
111
112        let buf = this.write_state.buf().expect("`Framed` not in idle state");
113        buf.clear();
114        buf.reserve(64);
115        this.codec.encode(item, buf)?;
116        this.framer.enclose(buf);
117        this.write_state.start_write();
118
119        Ok(())
120    }
121
122    fn poll_flush(
123        self: std::pin::Pin<&mut Self>,
124        cx: &mut std::task::Context<'_>,
125    ) -> std::task::Poll<Result<(), Self::Error>> {
126        let this = self.get_mut();
127        match &mut this.write_state {
128            State::Idle(_) => {
129                this.write_state.start_flush();
130                this.write_state.poll_sink(cx).map_err(C::Error::from)
131            }
132            State::Writing(_) | State::Flushing(_) => {
133                this.write_state.poll_sink(cx).map_err(C::Error::from)
134            }
135            State::Closing(_) => unreachable!("`Framed` is closing, cannot flush"),
136        }
137    }
138
139    fn poll_close(
140        self: std::pin::Pin<&mut Self>,
141        cx: &mut std::task::Context<'_>,
142    ) -> std::task::Poll<Result<(), Self::Error>> {
143        let this = self.get_mut();
144        match &mut this.write_state {
145            state @ State::Idle(_) => {
146                state.start_close();
147                state.poll_sink(cx).map_err(C::Error::from)
148            }
149            _ => this.write_state.poll_sink(cx).map_err(C::Error::from),
150        }
151    }
152}