madsim_tokio_postgres/
copy_in.rs

1use crate::client::{InnerClient, Responses};
2use crate::codec::FrontendMessage;
3use crate::connection::RequestMessages;
4use crate::{query, slice_iter, Error, Statement};
5use bytes::{Buf, BufMut, BytesMut};
6use futures::channel::mpsc;
7use futures::future;
8use futures::{ready, Sink, SinkExt, Stream, StreamExt};
9use log::debug;
10use pin_project_lite::pin_project;
11use postgres_protocol::message::backend::Message;
12use postgres_protocol::message::frontend;
13use postgres_protocol::message::frontend::CopyData;
14use std::marker::{PhantomData, PhantomPinned};
15use std::pin::Pin;
16use std::task::{Context, Poll};
17
18enum CopyInMessage {
19    Message(FrontendMessage),
20    Done,
21}
22
23pub struct CopyInReceiver {
24    receiver: mpsc::Receiver<CopyInMessage>,
25    done: bool,
26}
27
28impl CopyInReceiver {
29    fn new(receiver: mpsc::Receiver<CopyInMessage>) -> CopyInReceiver {
30        CopyInReceiver {
31            receiver,
32            done: false,
33        }
34    }
35}
36
37impl Stream for CopyInReceiver {
38    type Item = FrontendMessage;
39
40    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<FrontendMessage>> {
41        if self.done {
42            return Poll::Ready(None);
43        }
44
45        match ready!(self.receiver.poll_next_unpin(cx)) {
46            Some(CopyInMessage::Message(message)) => Poll::Ready(Some(message)),
47            Some(CopyInMessage::Done) => {
48                self.done = true;
49                let mut buf = BytesMut::new();
50                frontend::copy_done(&mut buf);
51                frontend::sync(&mut buf);
52                Poll::Ready(Some(FrontendMessage::Raw(buf.freeze())))
53            }
54            None => {
55                self.done = true;
56                let mut buf = BytesMut::new();
57                frontend::copy_fail("", &mut buf).unwrap();
58                frontend::sync(&mut buf);
59                Poll::Ready(Some(FrontendMessage::Raw(buf.freeze())))
60            }
61        }
62    }
63}
64
/// How far [`CopyInSink::poll_finish`] has progressed in completing the copy.
enum SinkState {
    /// Still accepting data; the `Done` message has not been queued yet.
    Active,
    /// `Done` has been queued; waiting for the channel to the connection to close.
    Closing,
    /// Channel closed; waiting for the server's `CommandComplete` response.
    Reading,
}
70
pin_project! {
    /// A sink for `COPY ... FROM STDIN` query data.
    ///
    /// The copy *must* be explicitly completed via the `Sink::close` or `finish` methods. If it is
    /// not, the copy will be aborted.
    ///
    /// Items of at most 4096 bytes are coalesced into an internal buffer, which is sent as a
    /// single `CopyData` message once it exceeds 4096 bytes or when the sink is flushed. Larger
    /// items are sent immediately, together with any bytes already buffered.
    pub struct CopyInSink<T> {
        // Channel to the connection task's `CopyInReceiver`; dropping it without queueing
        // `Done` makes the receiver emit `CopyFail`.
        #[pin]
        sender: mpsc::Sender<CopyInMessage>,
        // Server responses for this statement, read once the copy has been finished.
        responses: Responses,
        // Small writes not yet sent as a `CopyData` message.
        buf: BytesMut,
        // Progress of `poll_finish`.
        state: SinkState,
        // Pinned `PhantomPinned` field: keeps the sink `!Unpin`.
        #[pin]
        _p: PhantomPinned,
        // Ties the item type `T` accepted by the `Sink<T>` impl to the struct.
        _p2: PhantomData<T>,
    }
}
87
88impl<T> CopyInSink<T>
89where
90    T: Buf + 'static + Send,
91{
92    /// A poll-based version of `finish`.
93    pub fn poll_finish(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<u64, Error>> {
94        loop {
95            match self.state {
96                SinkState::Active => {
97                    ready!(self.as_mut().poll_flush(cx))?;
98                    let mut this = self.as_mut().project();
99                    ready!(this.sender.as_mut().poll_ready(cx)).map_err(|_| Error::closed())?;
100                    this.sender
101                        .start_send(CopyInMessage::Done)
102                        .map_err(|_| Error::closed())?;
103                    *this.state = SinkState::Closing;
104                }
105                SinkState::Closing => {
106                    let this = self.as_mut().project();
107                    ready!(this.sender.poll_close(cx)).map_err(|_| Error::closed())?;
108                    *this.state = SinkState::Reading;
109                }
110                SinkState::Reading => {
111                    let this = self.as_mut().project();
112                    match ready!(this.responses.poll_next(cx))? {
113                        Message::CommandComplete(body) => {
114                            let rows = body
115                                .tag()
116                                .map_err(Error::parse)?
117                                .rsplit(' ')
118                                .next()
119                                .unwrap()
120                                .parse()
121                                .unwrap_or(0);
122                            return Poll::Ready(Ok(rows));
123                        }
124                        _ => return Poll::Ready(Err(Error::unexpected_message())),
125                    }
126                }
127            }
128        }
129    }
130
131    /// Completes the copy, returning the number of rows inserted.
132    ///
133    /// The `Sink::close` method is equivalent to `finish`, except that it does not return the
134    /// number of rows.
135    pub async fn finish(mut self: Pin<&mut Self>) -> Result<u64, Error> {
136        future::poll_fn(|cx| self.as_mut().poll_finish(cx)).await
137    }
138}
139
140impl<T> Sink<T> for CopyInSink<T>
141where
142    T: Buf + 'static + Send,
143{
144    type Error = Error;
145
146    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
147        self.project()
148            .sender
149            .poll_ready(cx)
150            .map_err(|_| Error::closed())
151    }
152
153    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Error> {
154        let this = self.project();
155
156        let data: Box<dyn Buf + Send> = if item.remaining() > 4096 {
157            if this.buf.is_empty() {
158                Box::new(item)
159            } else {
160                Box::new(this.buf.split().freeze().chain(item))
161            }
162        } else {
163            this.buf.put(item);
164            if this.buf.len() > 4096 {
165                Box::new(this.buf.split().freeze())
166            } else {
167                return Ok(());
168            }
169        };
170
171        let data = CopyData::new(data).map_err(Error::encode)?;
172        this.sender
173            .start_send(CopyInMessage::Message(FrontendMessage::CopyData(data)))
174            .map_err(|_| Error::closed())
175    }
176
177    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
178        let mut this = self.project();
179
180        if !this.buf.is_empty() {
181            ready!(this.sender.as_mut().poll_ready(cx)).map_err(|_| Error::closed())?;
182            let data: Box<dyn Buf + Send> = Box::new(this.buf.split().freeze());
183            let data = CopyData::new(data).map_err(Error::encode)?;
184            this.sender
185                .as_mut()
186                .start_send(CopyInMessage::Message(FrontendMessage::CopyData(data)))
187                .map_err(|_| Error::closed())?;
188        }
189
190        this.sender.poll_flush(cx).map_err(|_| Error::closed())
191    }
192
193    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
194        self.poll_finish(cx).map_ok(|_| ())
195    }
196}
197
198pub async fn copy_in<T>(client: &InnerClient, statement: Statement) -> Result<CopyInSink<T>, Error>
199where
200    T: Buf + 'static + Send,
201{
202    debug!("executing copy in statement {}", statement.name());
203
204    let buf = query::encode(client, &statement, slice_iter(&[]))?;
205
206    let (mut sender, receiver) = mpsc::channel(1);
207    let receiver = CopyInReceiver::new(receiver);
208    let mut responses = client.send(RequestMessages::CopyIn(receiver))?;
209
210    sender
211        .send(CopyInMessage::Message(FrontendMessage::Raw(buf)))
212        .await
213        .map_err(|_| Error::closed())?;
214
215    match responses.next().await? {
216        Message::BindComplete => {}
217        _ => return Err(Error::unexpected_message()),
218    }
219
220    match responses.next().await? {
221        Message::CopyInResponse(_) => {}
222        _ => return Err(Error::unexpected_message()),
223    }
224
225    Ok(CopyInSink {
226        sender,
227        responses,
228        buf: BytesMut::new(),
229        state: SinkState::Active,
230        _p: PhantomPinned,
231        _p2: PhantomData,
232    })
233}