cdbc_pg/
copy.rs

1use cdbc::error::{Error, Result};
2use cdbc::pool::{Pool, PoolConnection};
3use crate::connection::PgConnection;
4use crate::message::{
5    CommandComplete, CopyData, CopyDone, CopyFail, CopyResponse, MessageFormat, Query,
6};
7use crate::Postgres;
8use bytes::{BufMut, Bytes};
9use smallvec::alloc::borrow::Cow;
10use std::convert::TryFrom;
11use std::ops::{Deref, DerefMut};
12use cdbc::io::chan_stream::ChanStream;
13use std::io::Write;
14
15impl PgConnection {
16    /// Issue a `COPY FROM STDIN` statement and transition the connection to streaming data
17    /// to Postgres. This is a more efficient way to import data into Postgres as compared to
18    /// `INSERT` but requires one of a few specific data formats (text/CSV/binary).
19    ///
20    /// If `statement` is anything other than a `COPY ... FROM STDIN ...` command, an error is
21    /// returned.
22    ///
23    /// Command examples and accepted formats for `COPY` data are shown here:
24    /// https://www.postgresql.org/docs/current/sql-copy.html
25    ///
26    /// ### Note
27    /// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection
28    /// will return an error the next time it is used.
29    pub fn copy_in_raw(&mut self, statement: &str) -> Result<PgCopyIn<&mut Self>> {
30        PgCopyIn::begin(self, statement)
31    }
32
33    /// Issue a `COPY TO STDOUT` statement and transition the connection to streaming data
34    /// from Postgres. This is a more efficient way to export data from Postgres but
35    /// arrives in chunks of one of a few data formats (text/CSV/binary).
36    ///
37    /// If `statement` is anything other than a `COPY ... TO STDOUT ...` command,
38    /// an error is returned.
39    ///
40    /// Note that once this process has begun, unless you read the stream to completion,
41    /// it can only be canceled in two ways:
42    ///
43    /// 1. by closing the connection, or:
44    /// 2. by using another connection to kill the server process that is sending the data as shown
45    /// [in this StackOverflow answer](https://stackoverflow.com/a/35319598).
46    ///
47    /// If you don't read the stream to completion, the next time the connection is used it will
48    /// need to read and discard all the remaining queued data, which could take some time.
49    ///
50    /// Command examples and accepted formats for `COPY` data are shown here:
51    /// https://www.postgresql.org/docs/current/sql-copy.html
52    #[allow(clippy::needless_lifetimes)]
53    pub fn copy_out_raw<'c>(
54        &'c mut self,
55        statement: &str,
56    ) -> Result<ChanStream<Bytes>> {
57        pg_begin_copy_out(self, statement)
58    }
59}
60
61pub trait CopyRaw{
62     fn copy_in_raw(&self, statement: &str) -> Result<PgCopyIn<PoolConnection<Postgres>>>;
63    fn copy_out_raw(&self, statement: &str) -> Result<ChanStream<Bytes>>;
64}
65
66impl CopyRaw for Pool<Postgres> {
67    /// Issue a `COPY FROM STDIN` statement and begin streaming data to Postgres.
68    /// This is a more efficient way to import data into Postgres as compared to
69    /// `INSERT` but requires one of a few specific data formats (text/CSV/binary).
70    ///
71    /// A single connection will be checked out for the duration.
72    ///
73    /// If `statement` is anything other than a `COPY ... FROM STDIN ...` command, an error is
74    /// returned.
75    ///
76    /// Command examples and accepted formats for `COPY` data are shown here:
77    /// https://www.postgresql.org/docs/current/sql-copy.html
78    ///
79    /// ### Note
80    /// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection
81    /// will return an error the next time it is used.
82    fn copy_in_raw(&self, statement: &str) -> Result<PgCopyIn<PoolConnection<Postgres>>> {
83        PgCopyIn::begin(self.acquire()?, statement)
84    }
85
86    /// Issue a `COPY TO STDOUT` statement and begin streaming data
87    /// from Postgres. This is a more efficient way to export data from Postgres but
88    /// arrives in chunks of one of a few data formats (text/CSV/binary).
89    ///
90    /// If `statement` is anything other than a `COPY ... TO STDOUT ...` command,
91    /// an error is returned.
92    ///
93    /// Note that once this process has begun, unless you read the stream to completion,
94    /// it can only be canceled in two ways:
95    ///
96    /// 1. by closing the connection, or:
97    /// 2. by using another connection to kill the server process that is sending the data as shown
98    /// [in this StackOverflow answer](https://stackoverflow.com/a/35319598).
99    ///
100    /// If you don't read the stream to completion, the next time the connection is used it will
101    /// need to read and discard all the remaining queued data, which could take some time.
102    ///
103    /// Command examples and accepted formats for `COPY` data are shown here:
104    /// https://www.postgresql.org/docs/current/sql-copy.html
105    fn copy_out_raw(&self, statement: &str) -> Result<ChanStream<Bytes>> {
106        pg_begin_copy_out(self.acquire()?, statement)
107    }
108}
109
110/// A connection in streaming `COPY FROM STDIN` mode.
111///
112/// Created by [PgConnection::copy_in_raw] or [Pool::copy_out_raw].
113///
114/// ### Note
115/// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection
116/// will return an error the next time it is used.
117#[must_use = "connection will error on next use if `.finish()` or `.abort()` is not called"]
118pub struct PgCopyIn<C: DerefMut<Target = PgConnection>> {
119    conn: Option<C>,
120    response: CopyResponse,
121}
122
123impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
124    fn begin(mut conn: C, statement: &str) -> Result<Self> {
125        conn.wait_until_ready()?;
126        conn.stream.send(Query(statement))?;
127
128        let response: CopyResponse = conn
129            .stream
130            .recv_expect(MessageFormat::CopyInResponse)
131            ?;
132
133        Ok(PgCopyIn {
134            conn: Some(conn),
135            response,
136        })
137    }
138
139    /// Returns `true` if Postgres is expecting data in text or CSV format.
140    pub fn is_textual(&self) -> bool {
141        self.response.format == 0
142    }
143
144    /// Returns the number of columns expected in the input.
145    pub fn num_columns(&self) -> usize {
146        assert_eq!(
147            self.response.num_columns as usize,
148            self.response.format_codes.len(),
149            "num_columns does not match format_codes.len()"
150        );
151        self.response.format_codes.len()
152    }
153
154    /// Check if a column is expecting data in text format (`true`) or binary format (`false`).
155    ///
156    /// ### Panics
157    /// If `column` is out of range according to [`.num_columns()`][Self::num_columns].
158    pub fn column_is_textual(&self, column: usize) -> bool {
159        self.response.format_codes[column] == 0
160    }
161
162    /// Send a chunk of `COPY` data.
163    ///
164    /// If you're copying data from an `AsyncRead`, maybe consider [Self::read_from] instead.
165    pub fn send(&mut self, data: impl Deref<Target = [u8]>) -> Result<&mut Self> {
166        self.conn
167            .as_deref_mut()
168            .expect("send_data: conn taken")
169            .stream
170            .send(CopyData(data))
171            ?;
172
173        Ok(self)
174    }
175
176    /// Copy data directly from `source` to the database without requiring an intermediate buffer.
177    ///
178    /// `source` will be read to the end.
179    ///
180    /// ### Note
181    /// You must still call either [Self::finish] or [Self::abort] to complete the process.
182    pub fn read_from(&mut self, mut source: impl std::io::Read) -> Result<&mut Self> {
183        // this is a separate guard from WriteAndFlush so we can reuse the buffer without zeroing
184        struct BufGuard<'s>(&'s mut Vec<u8>);
185
186        impl Drop for BufGuard<'_> {
187            fn drop(&mut self) {
188                self.0.clear()
189            }
190        }
191
192        let conn: &mut PgConnection = self.conn.as_deref_mut().expect("copy_from: conn taken");
193
194        // flush any existing messages in the buffer and clear it
195        conn.stream.flush()?;
196
197        {
198            let buf_stream = &mut *conn.stream;
199            let stream = &mut buf_stream.stream;
200
201            // ensures the buffer isn't left in an inconsistent state
202            let mut guard = BufGuard(&mut buf_stream.wbuf);
203
204            let buf: &mut Vec<u8> = &mut guard.0;
205            buf.push(b'd'); // CopyData format code
206            buf.resize(5, 0); // reserve space for the length
207
208            loop {
209                let read = match () {
210                    _ => {
211                        // should be a no-op unless len != capacity
212                        buf.resize(buf.capacity(), 0);
213                        source.read(&mut buf[5..])?
214                    }
215                };
216
217                if read == 0 {
218                    break;
219                }
220
221                let read32 = u32::try_from(read)
222                    .map_err(|_| err_protocol!("number of bytes read exceeds 2^32: {}", read))?;
223
224                (&mut buf[1..]).put_u32(read32 + 4);
225
226                stream.write_all(&buf[..read + 5])?;
227                stream.flush()?;
228            }
229        }
230
231        Ok(self)
232    }
233
234    /// Signal that the `COPY` process should be aborted and any data received should be discarded.
235    ///
236    /// The given message can be used for indicating the reason for the abort in the database logs.
237    ///
238    /// The server is expected to respond with an error, so only _unexpected_ errors are returned.
239    pub fn abort(mut self, msg: impl Into<String>) -> Result<()> {
240        let mut conn = self
241            .conn
242            .take()
243            .expect("PgCopyIn::fail_with: conn taken illegally");
244
245        conn.stream.send(CopyFail::new(msg))?;
246
247        match conn.stream.recv() {
248            Ok(msg) => Err(err_protocol!(
249                "fail_with: expected ErrorResponse, got: {:?}",
250                msg.format
251            )),
252            Err(Error::Database(e)) => {
253                match e.code() {
254                    Some(Cow::Borrowed("57014")) => {
255                        // postgres abort received error code
256                        conn.stream
257                            .recv_expect(MessageFormat::ReadyForQuery)
258                            ?;
259                        Ok(())
260                    }
261                    _ => Err(Error::Database(e)),
262                }
263            }
264            Err(e) => Err(e),
265        }
266    }
267
268    /// Signal that the `COPY` process is complete.
269    ///
270    /// The number of rows affected is returned.
271    pub fn finish(mut self) -> Result<u64> {
272        let mut conn = self
273            .conn
274            .take()
275            .expect("CopyWriter::finish: conn taken illegally");
276
277        conn.stream.send(CopyDone)?;
278        let cc: CommandComplete = conn
279            .stream
280            .recv_expect(MessageFormat::CommandComplete)
281            ?;
282
283        conn.stream
284            .recv_expect(MessageFormat::ReadyForQuery)
285            ?;
286
287        Ok(cc.rows_affected())
288    }
289}
290
291impl<C: DerefMut<Target = PgConnection>> Drop for PgCopyIn<C> {
292    fn drop(&mut self) {
293        if let Some(mut conn) = self.conn.take() {
294            conn.stream.write(CopyFail::new(
295                "PgCopyIn dropped without calling finish() or fail()",
296            ));
297        }
298    }
299}
300
/// Start a `COPY ... TO STDOUT` operation on `conn` and return a stream of the data chunks.
///
/// Waits until the connection is ready, sends `statement` as a query and expects a
/// `CopyOutResponse` back; a failure in any of these steps is returned immediately.
///
/// The returned stream then repeatedly receives messages from the connection:
/// * `CopyData`: the decoded payload is yielded as one `Bytes` item;
/// * `CopyDone`: `CommandComplete` and `ReadyForQuery` are awaited and the stream ends;
/// * anything else: the stream terminates with a protocol error.
///
/// `conn` is used inside the `chan_stream!` body, so it is presumably owned by the
/// stream for as long as the stream lives (confirm against the macro definition).
fn pg_begin_copy_out<'c, C: DerefMut<Target = PgConnection> + Send + 'c>(
    mut conn: C,
    statement: &str,
) -> Result<ChanStream<Bytes>> {
    conn.wait_until_ready()?;
    conn.stream.send(Query(statement))?;

    // The contents of the response are not needed; only its arrival is checked.
    let _: CopyResponse = conn
        .stream
        .recv_expect(MessageFormat::CopyOutResponse)
        ?;

    let stream = chan_stream! {
        loop {
            let msg = conn.stream.recv()?;
            match msg.format {
                // one chunk of copy data: hand its payload to the consumer
                MessageFormat::CopyData => r#yield!(msg.decode::<CopyData<Bytes>>()?.0),
                // end of data: consume the trailing messages, then finish the stream
                MessageFormat::CopyDone => {
                    let _ = msg.decode::<CopyDone>()?;
                    conn.stream.recv_expect(MessageFormat::CommandComplete)?;
                    conn.stream.recv_expect(MessageFormat::ReadyForQuery)?;
                    return Ok(())
                },
                _ => return Err(err_protocol!("unexpected message format during copy out: {:?}", msg.format))
            }
        }
    };

    Ok(stream)
}