// madsim_tokio_postgres/
copy_in.rs1use crate::client::{InnerClient, Responses};
2use crate::codec::FrontendMessage;
3use crate::connection::RequestMessages;
4use crate::{query, slice_iter, Error, Statement};
5use bytes::{Buf, BufMut, BytesMut};
6use futures::channel::mpsc;
7use futures::future;
8use futures::{ready, Sink, SinkExt, Stream, StreamExt};
9use log::debug;
10use pin_project_lite::pin_project;
11use postgres_protocol::message::backend::Message;
12use postgres_protocol::message::frontend;
13use postgres_protocol::message::frontend::CopyData;
14use std::marker::{PhantomData, PhantomPinned};
15use std::pin::Pin;
16use std::task::{Context, Poll};
17
18enum CopyInMessage {
19 Message(FrontendMessage),
20 Done,
21}
22
23pub struct CopyInReceiver {
24 receiver: mpsc::Receiver<CopyInMessage>,
25 done: bool,
26}
27
28impl CopyInReceiver {
29 fn new(receiver: mpsc::Receiver<CopyInMessage>) -> CopyInReceiver {
30 CopyInReceiver {
31 receiver,
32 done: false,
33 }
34 }
35}
36
37impl Stream for CopyInReceiver {
38 type Item = FrontendMessage;
39
40 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<FrontendMessage>> {
41 if self.done {
42 return Poll::Ready(None);
43 }
44
45 match ready!(self.receiver.poll_next_unpin(cx)) {
46 Some(CopyInMessage::Message(message)) => Poll::Ready(Some(message)),
47 Some(CopyInMessage::Done) => {
48 self.done = true;
49 let mut buf = BytesMut::new();
50 frontend::copy_done(&mut buf);
51 frontend::sync(&mut buf);
52 Poll::Ready(Some(FrontendMessage::Raw(buf.freeze())))
53 }
54 None => {
55 self.done = true;
56 let mut buf = BytesMut::new();
57 frontend::copy_fail("", &mut buf).unwrap();
58 frontend::sync(&mut buf);
59 Poll::Ready(Some(FrontendMessage::Raw(buf.freeze())))
60 }
61 }
62 }
63}
64
/// Progress of `CopyInSink::poll_finish`.
enum SinkState {
    /// Still accepting data; the `Done` marker has not been queued yet.
    Active,
    /// `Done` has been queued; the sending half of the channel is being closed.
    Closing,
    /// The channel is closed; waiting for the server's reply.
    Reading,
}

71pin_project! {
72 pub struct CopyInSink<T> {
77 #[pin]
78 sender: mpsc::Sender<CopyInMessage>,
79 responses: Responses,
80 buf: BytesMut,
81 state: SinkState,
82 #[pin]
83 _p: PhantomPinned,
84 _p2: PhantomData<T>,
85 }
86}
87
88impl<T> CopyInSink<T>
89where
90 T: Buf + 'static + Send,
91{
92 pub fn poll_finish(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<u64, Error>> {
94 loop {
95 match self.state {
96 SinkState::Active => {
97 ready!(self.as_mut().poll_flush(cx))?;
98 let mut this = self.as_mut().project();
99 ready!(this.sender.as_mut().poll_ready(cx)).map_err(|_| Error::closed())?;
100 this.sender
101 .start_send(CopyInMessage::Done)
102 .map_err(|_| Error::closed())?;
103 *this.state = SinkState::Closing;
104 }
105 SinkState::Closing => {
106 let this = self.as_mut().project();
107 ready!(this.sender.poll_close(cx)).map_err(|_| Error::closed())?;
108 *this.state = SinkState::Reading;
109 }
110 SinkState::Reading => {
111 let this = self.as_mut().project();
112 match ready!(this.responses.poll_next(cx))? {
113 Message::CommandComplete(body) => {
114 let rows = body
115 .tag()
116 .map_err(Error::parse)?
117 .rsplit(' ')
118 .next()
119 .unwrap()
120 .parse()
121 .unwrap_or(0);
122 return Poll::Ready(Ok(rows));
123 }
124 _ => return Poll::Ready(Err(Error::unexpected_message())),
125 }
126 }
127 }
128 }
129 }
130
131 pub async fn finish(mut self: Pin<&mut Self>) -> Result<u64, Error> {
136 future::poll_fn(|cx| self.as_mut().poll_finish(cx)).await
137 }
138}
139
140impl<T> Sink<T> for CopyInSink<T>
141where
142 T: Buf + 'static + Send,
143{
144 type Error = Error;
145
146 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
147 self.project()
148 .sender
149 .poll_ready(cx)
150 .map_err(|_| Error::closed())
151 }
152
153 fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Error> {
154 let this = self.project();
155
156 let data: Box<dyn Buf + Send> = if item.remaining() > 4096 {
157 if this.buf.is_empty() {
158 Box::new(item)
159 } else {
160 Box::new(this.buf.split().freeze().chain(item))
161 }
162 } else {
163 this.buf.put(item);
164 if this.buf.len() > 4096 {
165 Box::new(this.buf.split().freeze())
166 } else {
167 return Ok(());
168 }
169 };
170
171 let data = CopyData::new(data).map_err(Error::encode)?;
172 this.sender
173 .start_send(CopyInMessage::Message(FrontendMessage::CopyData(data)))
174 .map_err(|_| Error::closed())
175 }
176
177 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
178 let mut this = self.project();
179
180 if !this.buf.is_empty() {
181 ready!(this.sender.as_mut().poll_ready(cx)).map_err(|_| Error::closed())?;
182 let data: Box<dyn Buf + Send> = Box::new(this.buf.split().freeze());
183 let data = CopyData::new(data).map_err(Error::encode)?;
184 this.sender
185 .as_mut()
186 .start_send(CopyInMessage::Message(FrontendMessage::CopyData(data)))
187 .map_err(|_| Error::closed())?;
188 }
189
190 this.sender.poll_flush(cx).map_err(|_| Error::closed())
191 }
192
193 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
194 self.poll_finish(cx).map_ok(|_| ())
195 }
196}
197
198pub async fn copy_in<T>(client: &InnerClient, statement: Statement) -> Result<CopyInSink<T>, Error>
199where
200 T: Buf + 'static + Send,
201{
202 debug!("executing copy in statement {}", statement.name());
203
204 let buf = query::encode(client, &statement, slice_iter(&[]))?;
205
206 let (mut sender, receiver) = mpsc::channel(1);
207 let receiver = CopyInReceiver::new(receiver);
208 let mut responses = client.send(RequestMessages::CopyIn(receiver))?;
209
210 sender
211 .send(CopyInMessage::Message(FrontendMessage::Raw(buf)))
212 .await
213 .map_err(|_| Error::closed())?;
214
215 match responses.next().await? {
216 Message::BindComplete => {}
217 _ => return Err(Error::unexpected_message()),
218 }
219
220 match responses.next().await? {
221 Message::CopyInResponse(_) => {}
222 _ => return Err(Error::unexpected_message()),
223 }
224
225 Ok(CopyInSink {
226 sender,
227 responses,
228 buf: BytesMut::new(),
229 state: SinkState::Active,
230 _p: PhantomPinned,
231 _p2: PhantomData,
232 })
233}